Skip to content

Commit

Permalink
Fix cupy.cublas import
Browse files Browse the repository at this point in the history
Reported in #920.
  • Loading branch information
danieldk committed Feb 7, 2024
1 parent b183006 commit b3f4f47
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
4 changes: 2 additions & 2 deletions thinc/backends/cupy_ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy

from .. import registry
from ..compat import cupy, cupyx
from ..compat import cublas, cupy, cupyx
from ..types import DeviceTypes
from ..util import (
is_cupy_array,
Expand Down Expand Up @@ -257,7 +257,7 @@ def clip_gradient(self, gradient, threshold):
# implementation.
def frobenius_norm(X):
X_vec = X.reshape(-1)
return cupy.cublas.nrm2(X_vec)
return cublas.nrm2(X_vec)

grad_norm = cupy.maximum(frobenius_norm(gradient), 1e-12)
gradient *= cupy.minimum(threshold, grad_norm) / grad_norm
Expand Down
3 changes: 3 additions & 0 deletions thinc/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

try: # pragma: no cover
import cupy
import cupy.cublas
import cupyx

has_cupy = True
cublas = cupy.cublas
cupy_version = Version(cupy.__version__)
try:
cupy.cuda.runtime.getDeviceCount()
Expand All @@ -20,6 +22,7 @@
else:
cupy_from_dlpack = cupy.fromDlpack
except (ImportError, AttributeError):
cublas = None
cupy = None
cupyx = None
cupy_version = Version("0.0.0")
Expand Down

0 comments on commit b3f4f47

Please sign in to comment.