diff --git a/kauldron/typing/__init__.py b/kauldron/typing/__init__.py index ff31f79f..23e43ca2 100644 --- a/kauldron/typing/__init__.py +++ b/kauldron/typing/__init__.py @@ -38,11 +38,11 @@ UInt8, ) from kauldron.typing.shape_spec import Memo, Shape # pylint: disable=g-multiple-import,g-importing-member -from kauldron.typing.type_check import TypeCheckError, typechecked # pylint: disable=g-multiple-import,g-importing-member +from kauldron.typing.type_check import TypeCheckError, check_type, typechecked # pylint: disable=g-multiple-import,g-importing-member import numpy as np import typeguard as _typeguard # make typeguard.check_type accessible in this namespace -check_type = _typeguard.check_type +# check_type = _typeguard.check_type PRNGKey = UInt32["2"] PRNGKeyLike = Union[int, Sequence[int], np.ndarray, PRNGKey] diff --git a/kauldron/typing/shape_spec.py b/kauldron/typing/shape_spec.py index 5eb02b1b..406a85b3 100644 --- a/kauldron/typing/shape_spec.py +++ b/kauldron/typing/shape_spec.py @@ -23,10 +23,11 @@ import math import operator import typing -from typing import Any, List, Optional, Callable +from typing import Any, Callable, List, Optional import jaxtyping import lark +import tensorflow as tf # TODO(klausg): do not depend on this if typing.TYPE_CHECKING: @@ -288,6 +289,8 @@ def from_current_context(cls): def _maybe_remove_bool(memo): match memo: + case bool(_), tf.TensorShape(): + return tuple(d for d in memo[1]) case (bool(_), (*dims,)) if all(isinstance(d, int) for d in dims): return tuple(dims) case (*dims,) if all(isinstance(d, int) for d in dims): diff --git a/kauldron/typing/type_check.py b/kauldron/typing/type_check.py index 8b3c89e2..ceb1c6f0 100644 --- a/kauldron/typing/type_check.py +++ b/kauldron/typing/type_check.py @@ -16,13 +16,15 @@ from __future__ import annotations +import collections.abc import dataclasses import functools import inspect import re import sys import types -from typing import Any, Type, Union +import typing +from typing import Any, ForwardRef, Optional, Type, Union from etils import enp import jax @@ -38,7 +40,89 @@ _undef = object() -class TypeCheckError(typeguard.TypeCheckError): +# typing.get_origin(Dict[int, int]) -> dict +# typing.get_origin(Mapping[int, int]) -> collections.abc.Mapping +MAPPING_TYPES = (dict, collections.abc.Mapping, collections.abc.MutableMapping) +SEQUENCE_TYPES = (list, collections.abc.Sequence) + + +def check_type( + value: Any, + expected_type: Any, + *, + argname: str = "value", + memo: Optional[typeguard._TypeCheckMemo] = None, +) -> None: + """Runtime check if value conforms to an expected type.""" + if memo is None: + frame = sys._getframe(1) # pylint: disable=protected-access + memo = typeguard._TypeCheckMemo( # pylint: disable=protected-access + globals=frame.f_globals, locals=frame.f_locals + ) + + # first do the normal typecheck + typeguard.check_type( + argname, + value, + expected_type, + memo=memo, + ) + + # then if that worked, redo using our own check_type + _check_type_internal( + value, + expected_type, + argname=argname, + memo=memo, + ) + + +def _check_type_internal( + value: Any, + expected_type: Any, + *, + argname: str = "value", + memo: typeguard._TypeCheckMemo, +): + """Internal typechecking that takes care of checking shape annotations.""" + + if isinstance(expected_type, ForwardRef): + try: + expected_type = expected_type._evaluate( # pylint: disable=protected-access + memo.globals, memo.locals, recursive_guard=set() + ) + except Exception: # pylint: disable=broad-exception-caught + # ignore failed evaluations of forward refs + pass + + origin_type = typing.get_origin(expected_type) + + if origin_type in (Union, types.UnionType): + # TODO(klausg): handle Union of ArrayType with other types + args = expected_type.__args__ + if all(_is_array_type(arg) for arg in args): + return custom_array_type_union_checker(value, args, argname=argname) + # recurse into mappings, sequences etc. + if origin_type in MAPPING_TYPES: # mapping + if hasattr(expected_type, "__args__") and len(expected_type.__args__) == 2: + _, value_type = expected_type.__args__ + for k, v in value.items(): + _check_type_internal(v, value_type, argname=f"{argname}.{k}", memo=memo) + if origin_type in SEQUENCE_TYPES: + if hasattr(expected_type, "__args__") and len(expected_type.__args__) == 1: + value_type = expected_type.__args__[0] + for i, v in enumerate(value): + _check_type_internal( + v, value_type, argname=f"{argname}[{i}]", memo=memo + ) + if typing.is_typeddict(expected_type): + for k, v_type in expected_type.__annotations__.items(): + _check_type_internal( + value[k], v_type, argname=f"{argname}.{k}", memo=memo + ) + + +class TypeCheckError(TypeError): # pylint: disable=g-bad-exception-name """Indicates a runtime typechecking error from the @typechecked decorator.""" def __init__( @@ -98,14 +182,14 @@ def _reraise_with_shape_info(*args, _typecheck: bool = True, **kwargs): # manually reproduce the functionality of typeguard.typechecked, so that # we get access to the returnvalue of the function localns = sys._getframe(1).f_locals # pylint: disable=protected-access - memo = typeguard.CallMemo(python_func, localns, args=args, kwargs=kwargs) + memo = typeguard._CallMemo(python_func, localns, args=args, kwargs=kwargs) # pylint: disable=protected-access retval = _undef try: - typeguard.check_argument_types(memo) + _check_argument_types(memo) retval = fn(*args, **kwargs) - typeguard.check_return_type(retval, memo) + _check_return_type(retval, memo) return retval - except typeguard.TypeCheckError as e: + except TypeError as e: # Use function signature to construct a complete list of named arguments sig = inspect.signature(fn) bound_args = sig.bind(*args, **kwargs) @@ -243,12 +327,11 @@ def fail_message(self) -> str: def custom_array_type_union_checker( value: Any, - origin_type: Any, args: tuple[Any, ...], - memo: typeguard.TypeCheckMemo, + *, + argname: str = "value", ) -> None: """Custom checker for typeguard to better support Array type annotations.""" - del origin_type, memo individual_matches = [ArraySpecMatch(value, arg) for arg in args] correct_matches = [m.all_correct for m in individual_matches] if any(correct_matches): @@ -262,8 +345,9 @@ def custom_array_type_union_checker( # first check if any of the array types matches if not any(m.type_correct for m in individual_matches): acceptable_array_types = {arg.array_type for arg in args} - raise typeguard.TypeCheckError( - f"was of type {type(value)} which is none of {acceptable_array_types}" + raise TypeError( + f"{argname} was of type {type(value)} which is none of" + f" {acceptable_array_types}" ) # then check if any of the dtypes matches @@ -274,8 +358,9 @@ def custom_array_type_union_checker( options_str = f"any of {acceptable_dtypes}" else: options_str = f"{acceptable_dtypes[0]}" - raise typeguard.TypeCheckError( - f"was {value_spec_str} which is not dtype-compatible with {options_str}" + raise TypeError( + f"{argname} was {value_spec_str} which is not dtype-compatible with" + f" {options_str}" ) # then check if any of the shapes matches if not any(m.shape_correct for m in individual_matches): @@ -284,8 +369,9 @@ def custom_array_type_union_checker( options_str = f"any of {acceptable_shapes}" else: options_str = f"'{acceptable_shapes[0]}'" - raise typeguard.TypeCheckError( - f"was {value_spec_str} which is not shape-compatible with {options_str}" + raise TypeError( + f"{argname} was {value_spec_str} which is not shape-compatible with" + f" {options_str}" ) # None of the three factors alone fail, but a combination of them does. @@ -293,8 +379,9 @@ def custom_array_type_union_checker( fail_messages = "\n".join( " - " + m.fail_message() for m in individual_matches if m.is_interesting ) - raise typeguard.TypeCheckError( - f"was {value_spec_str} which did not match any of:\n{fail_messages}" + raise TypeError( + f"{argname} was {value_spec_str} which did not match any" + f" of:\n{fail_messages}" ) @@ -360,24 +447,46 @@ def array_spec_checker_lookup( return None -def add_custom_checker_lookup_fn(lookup_fn): - """Add custom array spec checker lookup function to typeguard.""" - # Add custom array spec checker lookup function to typguard - # check not for equality but for qualname, to avoid many copies when - # reloading modules from colab - if hasattr(typeguard, "checker_lookup_functions"): - # Recent `typeguard` has different API - checker_lookup_fns = typeguard.checker_lookup_functions - else: - # TODO(epot): Remove once typeguard is updated - checker_lookup_fns = typeguard.config.checker_lookup_functions - for i, f in enumerate(checker_lookup_fns): - if f.__qualname__ == lookup_fn.__qualname__: - # replace - checker_lookup_fns[i : i + 1] = [lookup_fn] - break - else: # prepend - checker_lookup_fns[:0] = [lookup_fn] - - -add_custom_checker_lookup_fn(array_spec_checker_lookup) +def _check_argument_types(memo: typeguard._CallMemo) -> bool: + """Check that the argument values match the annotated types.""" + # raise TypeError("check_argument_shape_annotations not implemented") + for argname, expected_type in memo.type_hints.items(): + if argname != "return" and argname in memo.arguments: + value = memo.arguments[argname] + description = 'argument "{}"'.format(argname) + try: + check_type(value, expected_type, argname=description) + except TypeError as exc: # suppress unnecessarily long tracebacks + raise TypeError(*exc.args) from None + + return True + + +def _check_return_type(retval: Any, memo: typeguard._CallMemo) -> bool: + """Internal implementation of checking the return value of a func call.""" + if "return" in memo.type_hints: + if memo.type_hints["return"] is typeguard.NoReturn: + raise TypeError( + "{}() was declared never to return but it did".format(memo.func_name) + ) + try: + check_type( + retval, + memo.type_hints["return"], + argname="the return value", + memo=memo, + ) + except TypeError as exc: # suppress unnecessarily long tracebacks + # Allow NotImplemented if this is a binary magic method (__eq__() et al) + if retval is NotImplemented and memo.type_hints["return"] is bool: + # This does (and cannot) not check if it's actually a method + func_name = memo.func_name.rsplit(".", 1)[-1] + if ( + len(memo.arguments) == 2 + and func_name in typeguard.BINARY_MAGIC_METHODS + ): + return True + + raise TypeError(*exc.args) from None + + return True