diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 035be0baf..1ed106d59 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -87,7 +87,9 @@ jobs: - name: Run mypy run: python -m mypy thinc --no-implicit-reexport - if: matrix.python_version != '3.6' + if: | + matrix.python_version != '3.6' && + matrix.python_version != '3.7' - name: Delete source directory run: rm -rf thinc diff --git a/requirements.txt b/requirements.txt index b7682e738..3e3c9901e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,7 +25,7 @@ pytest-cov>=2.7.0,<5.0.0 coverage>=5.0.0,<8.0.0 mock>=2.0.0,<3.0.0 flake8>=3.5.0,<3.6.0 -mypy>=1.0.0,<1.1.0; python_version >= "3.7" +mypy>=1.5.0,<1.6.0; platform_machine != "aarch64" and python_version >= "3.8" types-mock>=0.1.1 types-contextvars>=0.1.2; python_version < "3.7" types-dataclasses>=0.1.3; python_version < "3.7" diff --git a/thinc/backends/cupy_ops.py b/thinc/backends/cupy_ops.py index 1e1e5b92b..472b6c542 100644 --- a/thinc/backends/cupy_ops.py +++ b/thinc/backends/cupy_ops.py @@ -1,7 +1,7 @@ import numpy from .. import registry -from ..compat import cupy, cupyx +from ..compat import cublas, cupy, cupyx from ..types import DeviceTypes from ..util import ( is_cupy_array, @@ -257,7 +257,7 @@ def clip_gradient(self, gradient, threshold): # implementation. def frobenius_norm(X): X_vec = X.reshape(-1) - return cupy.cublas.nrm2(X_vec) + return cublas.nrm2(X_vec) grad_norm = cupy.maximum(frobenius_norm(gradient), 1e-12) gradient *= cupy.minimum(threshold, grad_norm) / grad_norm diff --git a/thinc/compat.py b/thinc/compat.py index 5d600796a..c7b47cbe6 100644 --- a/thinc/compat.py +++ b/thinc/compat.py @@ -4,9 +4,11 @@ try: # pragma: no cover import cupy + import cupy.cublas import cupyx has_cupy = True + cublas = cupy.cublas cupy_version = Version(cupy.__version__) try: cupy.cuda.runtime.getDeviceCount() @@ -20,6 +22,7 @@ else: cupy_from_dlpack = cupy.fromDlpack except (ImportError, AttributeError): + cublas = None cupy = None cupyx = None cupy_version = Version("0.0.0") diff --git a/thinc/shims/torchscript.py b/thinc/shims/torchscript.py index 6c05c8a9b..9d413f93a 100644 --- a/thinc/shims/torchscript.py +++ b/thinc/shims/torchscript.py @@ -30,7 +30,7 @@ class TorchScriptShim(PyTorchShim): def __init__( self, - model: Optional["torch.ScriptModule"], + model: Optional["torch.jit.ScriptModule"], config=None, optimizer: Any = None, mixed_precision: bool = False,