Skip to content

Commit

Permalink
fix typeguard integration
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 599235762
  • Loading branch information
Qwlouse authored and The kauldron Authors committed Jan 17, 2024
1 parent c8101b8 commit 86849d3
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 41 deletions.
4 changes: 2 additions & 2 deletions kauldron/typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@
UInt8,
)
from kauldron.typing.shape_spec import Memo, Shape # pylint: disable=g-multiple-import,g-importing-member
from kauldron.typing.type_check import TypeCheckError, typechecked # pylint: disable=g-multiple-import,g-importing-member
from kauldron.typing.type_check import TypeCheckError, check_type, typechecked # pylint: disable=g-multiple-import,g-importing-member
import numpy as np
import typeguard as _typeguard
# make typeguard.check_type accessible in this namespace
check_type = _typeguard.check_type
# check_type = _typeguard.check_type

PRNGKey = UInt32["2"]
PRNGKeyLike = Union[int, Sequence[int], np.ndarray, PRNGKey]
Expand Down
5 changes: 4 additions & 1 deletion kauldron/typing/shape_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@
import math
import operator
import typing
from typing import Any, List, Optional, Callable
from typing import Any, Callable, List, Optional

import jaxtyping
import lark
import tensorflow as tf # TODO(klausg): do not depend on this


if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -288,6 +289,8 @@ def from_current_context(cls):

def _maybe_remove_bool(memo):
match memo:
case bool(_), tf.TensorShape():
return tuple(d for d in memo[1])
case (bool(_), (*dims,)) if all(isinstance(d, int) for d in dims):
return tuple(dims)
case (*dims,) if all(isinstance(d, int) for d in dims):
Expand Down
185 changes: 147 additions & 38 deletions kauldron/typing/type_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@

from __future__ import annotations

import collections.abc
import dataclasses
import functools
import inspect
import re
import sys
import types
from typing import Any, Type, Union
import typing
from typing import Any, ForwardRef, Optional, Type, Union

from etils import enp
import jax
Expand All @@ -38,7 +40,89 @@
_undef = object()


class TypeCheckError(typeguard.TypeCheckError):
# Origin types (as returned by `typing.get_origin`) whose element annotations
# `_check_type_internal` recurses into. Examples of how annotations map to
# origins:
#   typing.get_origin(Dict[int, int]) -> dict
#   typing.get_origin(Mapping[int, int]) -> collections.abc.Mapping
MAPPING_TYPES = (dict, collections.abc.Mapping, collections.abc.MutableMapping)
# NOTE: `tuple` is not listed, so annotations whose origin is `tuple` do not
# take the sequence branch of `_check_type_internal`.
SEQUENCE_TYPES = (list, collections.abc.Sequence)


def check_type(
    value: Any,
    expected_type: Any,
    *,
    argname: str = "value",
    memo: Optional[typeguard._TypeCheckMemo] = None,
) -> None:
  """Checks at runtime that `value` conforms to `expected_type`.

  The check runs in two passes: first the regular `typeguard.check_type`, and
  then, only if that did not raise, `_check_type_internal`, which walks the
  annotation again to apply the array/shape specific checks.

  Args:
    value: The object to check.
    expected_type: The annotation that `value` is checked against.
    argname: Name used to refer to `value` in error messages.
    memo: Namespace holder passed on to both passes. When omitted, one is built
      from the globals and locals of the frame that called `check_type`.
  """
  if memo is None:
    # Frame 1 is the direct caller of this function; its namespaces are the
    # ones handed to typeguard through the memo.
    caller_frame = sys._getframe(1)  # pylint: disable=protected-access
    memo = typeguard._TypeCheckMemo(  # pylint: disable=protected-access
        globals=caller_frame.f_globals,
        locals=caller_frame.f_locals,
    )

  # Pass 1: plain typeguard check.
  typeguard.check_type(argname, value, expected_type, memo=memo)

  # Pass 2: reached only when pass 1 succeeded; adds the custom checks.
  _check_type_internal(value, expected_type, argname=argname, memo=memo)


def _check_type_internal(
    value: Any,
    expected_type: Any,
    *,
    argname: str = "value",
    memo: typeguard._TypeCheckMemo,
):
  """Internal typechecking that takes care of checking shape annotations.

  Walks `expected_type` alongside `value`:
    * A `ForwardRef` is evaluated against `memo.globals` / `memo.locals`
      (best effort; a failed evaluation leaves the reference untouched).
    * A union whose members are all array types is delegated to
      `custom_array_type_union_checker`.
    * Mappings with two type arguments, sequences with one type argument and
      TypedDicts are recursed into, element by element.
  Any other annotation is left alone.

  Args:
    value: The object to check.
    expected_type: The annotation (or `ForwardRef` to one) to check against.
    argname: Name of `value` for error messages; extended with `.key` or
      `[index]` while recursing.
    memo: Provides the namespaces used to evaluate forward references.

  Returns:
    The result of `custom_array_type_union_checker` for all-array unions,
    otherwise None.
  """

  if isinstance(expected_type, ForwardRef):
    try:
      expected_type = expected_type._evaluate(  # pylint: disable=protected-access
          memo.globals, memo.locals, recursive_guard=set()
      )
    except Exception:  # pylint: disable=broad-exception-caught
      # Deliberately best-effort: an unresolvable forward ref is kept as-is
      # and simply matches none of the branches below.
      pass

  origin_type = typing.get_origin(expected_type)

  if origin_type in (Union, types.UnionType):
    # TODO(klausg): handle Union of ArrayType with other types
    args = expected_type.__args__
    if all(_is_array_type(arg) for arg in args):
      return custom_array_type_union_checker(value, args, argname=argname)

  # Recurse into mappings, sequences etc.
  if origin_type in MAPPING_TYPES:
    if hasattr(expected_type, "__args__") and len(expected_type.__args__) == 2:
      # Only the value annotation is recursed into; keys are not re-checked.
      _, value_type = expected_type.__args__
      for k, v in value.items():
        _check_type_internal(v, value_type, argname=f"{argname}.{k}", memo=memo)

  if origin_type in SEQUENCE_TYPES:
    if hasattr(expected_type, "__args__") and len(expected_type.__args__) == 1:
      value_type = expected_type.__args__[0]
      for i, v in enumerate(value):
        _check_type_internal(
            v, value_type, argname=f"{argname}[{i}]", memo=memo
        )

  if typing.is_typeddict(expected_type):
    for k, v_type in expected_type.__annotations__.items():
      if k not in value:
        # Keys of a TypedDict can be optional (total=False / NotRequired), so
        # an annotated key may legitimately be missing. Indexing it would
        # raise KeyError instead of a type error; skip it.
        continue
      _check_type_internal(
          value[k], v_type, argname=f"{argname}.{k}", memo=memo
      )


class TypeCheckError(TypeError): # pylint: disable=g-bad-exception-name
"""Indicates a runtime typechecking error from the @typechecked decorator."""

def __init__(
Expand Down Expand Up @@ -98,14 +182,14 @@ def _reraise_with_shape_info(*args, _typecheck: bool = True, **kwargs):
# manually reproduce the functionality of typeguard.typechecked, so that
# we get access to the returnvalue of the function
localns = sys._getframe(1).f_locals # pylint: disable=protected-access
memo = typeguard.CallMemo(python_func, localns, args=args, kwargs=kwargs)
memo = typeguard._CallMemo(python_func, localns, args=args, kwargs=kwargs) # pylint: disable=protected-access
retval = _undef
try:
typeguard.check_argument_types(memo)
_check_argument_types(memo)
retval = fn(*args, **kwargs)
typeguard.check_return_type(retval, memo)
_check_return_type(retval, memo)
return retval
except typeguard.TypeCheckError as e:
except TypeError as e:
# Use function signature to construct a complete list of named arguments
sig = inspect.signature(fn)
bound_args = sig.bind(*args, **kwargs)
Expand Down Expand Up @@ -243,12 +327,11 @@ def fail_message(self) -> str:

def custom_array_type_union_checker(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: typeguard.TypeCheckMemo,
*,
argname: str = "value",
) -> None:
"""Custom checker for typeguard to better support Array type annotations."""
del origin_type, memo
individual_matches = [ArraySpecMatch(value, arg) for arg in args]
correct_matches = [m.all_correct for m in individual_matches]
if any(correct_matches):
Expand All @@ -262,8 +345,9 @@ def custom_array_type_union_checker(
# first check if any of the array types matches
if not any(m.type_correct for m in individual_matches):
acceptable_array_types = {arg.array_type for arg in args}
raise typeguard.TypeCheckError(
f"was of type {type(value)} which is none of {acceptable_array_types}"
raise TypeError(
f"{argname} was of type {type(value)} which is none of"
f" {acceptable_array_types}"
)

# then check if any of the dtypes matches
Expand All @@ -274,8 +358,9 @@ def custom_array_type_union_checker(
options_str = f"any of {acceptable_dtypes}"
else:
options_str = f"{acceptable_dtypes[0]}"
raise typeguard.TypeCheckError(
f"was {value_spec_str} which is not dtype-compatible with {options_str}"
raise TypeError(
f"{argname} was {value_spec_str} which is not dtype-compatible with"
f" {options_str}"
)
# then check if any of the shapes matches
if not any(m.shape_correct for m in individual_matches):
Expand All @@ -284,17 +369,19 @@ def custom_array_type_union_checker(
options_str = f"any of {acceptable_shapes}"
else:
options_str = f"'{acceptable_shapes[0]}'"
raise typeguard.TypeCheckError(
f"was {value_spec_str} which is not shape-compatible with {options_str}"
raise TypeError(
f"{argname} was {value_spec_str} which is not shape-compatible with"
f" {options_str}"
)

# None of the three factors alone fail, but a combination of them does.
# That means we compile a list of interesting failures:
fail_messages = "\n".join(
" - " + m.fail_message() for m in individual_matches if m.is_interesting
)
raise typeguard.TypeCheckError(
f"was {value_spec_str} which did not match any of:\n{fail_messages}"
raise TypeError(
f"{argname} was {value_spec_str} which did not match any"
f" of:\n{fail_messages}"
)


Expand Down Expand Up @@ -360,24 +447,46 @@ def array_spec_checker_lookup(
return None


def add_custom_checker_lookup_fn(lookup_fn):
  """Registers `lookup_fn` in typeguard's list of checker lookup functions.

  Entries are matched by `__qualname__`, not by equality, so that reloading
  the module (e.g. from colab) replaces the previous entry in place instead of
  accumulating copies. If no entry with the same qualname exists, `lookup_fn`
  is placed at the front of the list.

  Args:
    lookup_fn: The checker lookup function to register.
  """
  # The list lives in a different place depending on the typeguard version.
  if hasattr(typeguard, "checker_lookup_functions"):
    # Recent `typeguard` has different API
    registry = typeguard.checker_lookup_functions
  else:
    # TODO(epot): Remove once typeguard is updated
    registry = typeguard.config.checker_lookup_functions

  for index, registered in enumerate(registry):
    if registered.__qualname__ == lookup_fn.__qualname__:
      # Same function from an earlier load: swap it out in place.
      registry[index] = lookup_fn
      return

  # Not registered yet: prepend.
  registry.insert(0, lookup_fn)


add_custom_checker_lookup_fn(array_spec_checker_lookup)
def _check_argument_types(memo: typeguard._CallMemo) -> bool:
  """Check that the argument values match the annotated types.

  Args:
    memo: Call memo of the function being checked; provides its type hints
      (`memo.type_hints`) and the bound argument values (`memo.arguments`).

  Returns:
    True if every annotated argument that was passed conforms to its type.

  Raises:
    TypeError: If an argument fails `check_type`. The error is re-raised with
      the same args but without the chained traceback.
  """
  for argname, expected_type in memo.type_hints.items():
    if argname == "return" or argname not in memo.arguments:
      continue
    description = f'argument "{argname}"'
    try:
      # Pass the call memo along (as `_check_return_type` does); otherwise
      # `check_type` would build a memo from this helper's own frame rather
      # than from the namespace of the function being checked.
      check_type(
          memo.arguments[argname],
          expected_type,
          argname=description,
          memo=memo,
      )
    except TypeError as exc:  # suppress unnecessarily long tracebacks
      raise TypeError(*exc.args) from None

  return True


def _check_return_type(retval: Any, memo: typeguard._CallMemo) -> bool:
  """Checks the return value of a function call against its annotation.

  Args:
    retval: The value the function returned.
    memo: Call memo of the function; provides `type_hints`, `func_name` and
      `arguments`.

  Returns:
    True if there is no return annotation, if `retval` conforms to it, or if
    `retval` is `NotImplemented` coming from a two-argument binary magic
    method annotated to return `bool`.

  Raises:
    TypeError: If the function is annotated as `NoReturn` but returned, or if
      `check_type` rejects `retval`.
  """
  if "return" not in memo.type_hints:
    return True

  if memo.type_hints["return"] is typeguard.NoReturn:
    raise TypeError(
        "{}() was declared never to return but it did".format(memo.func_name)
    )

  try:
    check_type(
        retval,
        memo.type_hints["return"],
        argname="the return value",
        memo=memo,
    )
  except TypeError as exc:  # suppress unnecessarily long tracebacks
    # Allow NotImplemented if this is a binary magic method (__eq__() et al).
    returned_not_implemented = (
        retval is NotImplemented and memo.type_hints["return"] is bool
    )
    if returned_not_implemented:
      # This does not (and cannot) check whether it is actually a method.
      bare_name = memo.func_name.rsplit(".", 1)[-1]
      is_binary_magic = (
          len(memo.arguments) == 2
          and bare_name in typeguard.BINARY_MAGIC_METHODS
      )
      if is_binary_magic:
        return True

    raise TypeError(*exc.args) from None

  return True

0 comments on commit 86849d3

Please sign in to comment.