diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 035be0baf..cd569bafa 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -87,7 +87,9 @@ jobs: - name: Run mypy run: python -m mypy thinc --no-implicit-reexport - if: matrix.python_version != '3.6' + if: | + matrix.python_version != '3.6' && + matrix.python_version != '3.7' - name: Delete source directory run: rm -rf thinc @@ -150,14 +152,3 @@ jobs: - name: Run tests with extras run: python -m pytest --pyargs thinc --cov=thinc --cov-report=term -p thinc.tests.enable_tensorflow -p thinc.tests.enable_mxnet - - - name: Run tests for thinc-apple-ops - run: | - pip uninstall -y tensorflow - pip install thinc-apple-ops - python -m pytest --pyargs thinc_apple_ops - if: matrix.os == 'macos-latest' && matrix.python_version == '3.10' - - - name: Run tests with thinc-apple-ops - run: python -m pytest --pyargs thinc - if: matrix.os == 'macos-latest' && matrix.python_version == '3.10' diff --git a/requirements.txt b/requirements.txt index b7682e738..3e3c9901e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,7 +25,7 @@ pytest-cov>=2.7.0,<5.0.0 coverage>=5.0.0,<8.0.0 mock>=2.0.0,<3.0.0 flake8>=3.5.0,<3.6.0 -mypy>=1.0.0,<1.1.0; python_version >= "3.7" +mypy>=1.5.0,<1.6.0; platform_machine != "aarch64" and python_version >= "3.8" types-mock>=0.1.1 types-contextvars>=0.1.2; python_version < "3.7" types-dataclasses>=0.1.3; python_version < "3.7" diff --git a/setup.py b/setup.py index d2c717be9..e380c815c 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +import platform import sys from setuptools.command.build_ext import build_ext from sysconfig import get_path @@ -13,16 +14,16 @@ # http://docs.cython.org/en/latest/src/userguide/source_files_and_compilation.html#compiler-options Options.docstrings = True +ACCELERATE = "thinc.backends._accelerate" +APPLE_OPS = ["thinc.backends.apple_ops", ACCELERATE] PACKAGES = find_packages() MOD_NAMES = [ "thinc.backends.cblas", - "thinc.backends.linalg", "thinc.backends.numpy_ops", - "thinc.extra.search", "thinc.layers.sparselinear", "thinc.layers.premap_ids", -] +] + (APPLE_OPS if platform.system() == "Darwin" else []) COMPILE_OPTIONS = { "msvc": ["/Ox", "/EHsc"], "other": ["-O3", "-Wno-strict-prototypes", "-Wno-unused-function", "-std=c++11"], @@ -80,7 +81,16 @@ def setup_package(): ext_modules = [] for name in MOD_NAMES: mod_path = name.replace(".", "/") + ".pyx" - ext = Extension(name, [mod_path], language="c++", include_dirs=include_dirs) + if name == ACCELERATE: + ext = Extension( + name, + [mod_path], + language="c++", + include_dirs=include_dirs, + libraries=["blas"], + ) + else: + ext = Extension(name, [mod_path], language="c++", include_dirs=include_dirs) ext_modules.append(ext) print("Cythonizing sources") ext_modules = cythonize( diff --git a/thinc/about.py b/thinc/about.py index 394a8253e..d2a73d579 100644 --- a/thinc/about.py +++ b/thinc/about.py @@ -1,2 +1,2 @@ -__version__ = "8.2.2" +__version__ = "9.0.0" __release__ = True diff --git a/thinc/api.py b/thinc/api.py index 204aa386e..798ef6f08 100644 --- a/thinc/api.py +++ b/thinc/api.py @@ -119,11 +119,13 @@ ) from .optimizers import SGD, Adam, Optimizer, RAdam from .schedules import ( + Schedule, compounding, constant, constant_then, cyclic_triangular, decaying, + plateau, slanted_triangular, warmup_linear, ) @@ -160,6 +162,11 @@ xp2torch, ) +try: + from .backends import AppleOps +except ImportError: + AppleOps = None + # fmt: off __all__ = [ # .config @@ -179,8 +186,8 @@ 
# .optimizers "Adam", "RAdam", "SGD", "Optimizer", # .schedules - "cyclic_triangular", "warmup_linear", "constant", "constant_then", - "decaying", "slanted_triangular", "compounding", + "Schedule", "cyclic_triangular", "warmup_linear", "constant", "constant_then", + "decaying", "slanted_triangular", "compounding", "plateau", # .types "Ragged", "Padded", "ArgsKwargs", "Unserializable", # .util @@ -196,7 +203,7 @@ "has_cupy", # .backends "get_ops", "set_current_ops", "get_current_ops", "use_ops", - "Ops", "CupyOps", "MPSOps", "NumpyOps", "set_gpu_allocator", + "Ops", "AppleOps", "CupyOps", "MPSOps", "NumpyOps", "set_gpu_allocator", "use_pytorch_for_gpu_memory", "use_tensorflow_for_gpu_memory", # .layers "Dropout", "Embed", "expand_window", "HashEmbed", "LayerNorm", "Linear", diff --git a/thinc/backends/__init__.py b/thinc/backends/__init__.py index 8973c8836..5d33c2c34 100644 --- a/thinc/backends/__init__.py +++ b/thinc/backends/__init__.py @@ -19,6 +19,11 @@ from .numpy_ops import NumpyOps from .ops import Ops +try: + from .apple_ops import AppleOps +except ImportError: + AppleOps = None + context_ops: ContextVar[Optional[Ops]] = ContextVar("context_ops", default=None) context_pools: ContextVar[dict] = ContextVar("context_pools", default={}) @@ -26,6 +31,9 @@ # notebook might not have preserved contextvars across cells. _GLOBAL_STATE = {"ops": None} +# Thread-local state. +_LOCAL_STATE = threading.local() + def set_gpu_allocator(allocator: str) -> None: # pragma: no cover """Route GPU memory allocation via PyTorch or tensorflow. @@ -80,10 +88,6 @@ def use_tensorflow_for_gpu_memory() -> None: # pragma: no cover def _import_extra_cpu_backends(): - try: - from thinc_apple_ops import AppleOps - except ImportError: - pass try: from thinc_bigendian_ops import BigEndianOps except ImportError: @@ -152,22 +156,14 @@ def contextvars_eq_thread_ops() -> bool: return False -def _get_thread_state(): +def _get_thread_state() -> threading.local: """Get a thread-specific state variable that inherits from a global state when it's created.""" - thread: threading.Thread = threading.current_thread() - if not hasattr(thread, "__local"): - thread.__local = _create_thread_local(_GLOBAL_STATE) - return thread.__local - - -def _create_thread_local( - attrs: Dict[str, Any], local_class: Type[threading.local] = threading.local -): - obj = local_class() - for name, value in attrs.items(): - setattr(obj, name, value) - return obj + if not hasattr(_LOCAL_STATE, "initialized") or not _LOCAL_STATE.initialized: + for name, value in _GLOBAL_STATE.items(): + setattr(_LOCAL_STATE, name, value) + _LOCAL_STATE.initialized = True + return _LOCAL_STATE __all__ = [ @@ -176,6 +172,7 @@ def _create_thread_local( "use_ops", "ParamServer", "Ops", + "AppleOps", "CupyOps", "MPSOps", "NumpyOps", diff --git a/thinc/backends/_accelerate.pxd b/thinc/backends/_accelerate.pxd new file mode 100644 index 000000000..8bc0ce233 --- /dev/null +++ b/thinc/backends/_accelerate.pxd @@ -0,0 +1,40 @@ +cdef extern from "Accelerate/Accelerate.h": + enum CBLAS_ORDER: CblasRowMajor, CblasColMajor + enum CBLAS_TRANSPOSE: CblasNoTrans, CblasTrans, CblasConjTrans + enum CBLAS_UPLO: CblasUpper, CblasLower + enum CBLAS_DIAG: CblasNonUnit, CblasUnit + enum CBLAS_SIDE: CblasLeft, CblasRight + + # BLAS level 1 routines + + void cblas_sswap(int M, float *x, int incX, float *y, int incY) nogil + void cblas_sscal(int N, float alpha, float *x, int incX) nogil + void cblas_scopy(int N, float *x, int incX, float *y, int incY) nogil + void cblas_saxpy(int N, float alpha, 
float *x, int incX, float *y, int incY ) nogil + float cblas_sdot(int N, float *x, int incX, float *y, int incY ) nogil + float cblas_snrm2(int N, float *x, int incX) nogil + float cblas_sasum(int N, float *x, int incX) nogil + int cblas_isamax(int N, float *x, int incX) nogil + + # BLAS level 2 routines + void cblas_sgemv(CBLAS_ORDER Order, CBLAS_TRANSPOSE TransA, int M, int N, + float alpha, float *A, int lda, float *x, int incX, + float beta, float *y, int incY) nogil + + void cblas_sger(CBLAS_ORDER Order, int M, int N, float alpha, float *x, + int incX, float *y, int incY, float *A, int lda) nogil + + # BLAS level 3 routines + void cblas_sgemm(CBLAS_ORDER Order, CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, int M, int N, int K, + float alpha, float *A, int lda, float *B, int ldb, + float beta, float *C, int ldc) nogil + + +cdef void sgemm(bint TransA, bint TransB, int M, int N, int K, + float alpha, const float* A, int lda, const float *B, + int ldb, float beta, float* C, int ldc) nogil + + +cdef void saxpy(int N, float alpha, const float* X, int incX, + float *Y, int incY) nogil diff --git a/thinc/backends/_accelerate.pyx b/thinc/backends/_accelerate.pyx new file mode 100644 index 000000000..094cb9443 --- /dev/null +++ b/thinc/backends/_accelerate.pyx @@ -0,0 +1,75 @@ +cimport numpy as np +from libc.stdint cimport uintptr_t + +import numpy + + +cpdef np.ndarray gemm(float[:, ::1] A, float[:, ::1] B, + bint trans1=False, bint trans2=False, + np.ndarray out=None): + cdef int nM = A.shape[0] if not trans1 else A.shape[1] + cdef int nK = A.shape[1] if not trans1 else A.shape[0] + cdef int nK_b = B.shape[0] if not trans2 else B.shape[1] + cdef int nN = B.shape[1] if not trans2 else B.shape[0] + + cdef float[:, ::1] C = out + + if out is None: + out = numpy.empty((nM, nN), dtype="f") + C = out + else: + if C.shape[0] != nM or C.shape[1] != nN: + msg = "Shape mismatch for output matrix, was: (%d, %d), expected (%d, %d)" + raise ValueError(msg % (C.shape[0], C.shape[1], nM, nN)) + + + if nK != nK_b: + msg = "Shape mismatch for gemm: (%d, %d), (%d, %d)" + raise ValueError(msg % (nM, nK, nK_b, nN)) + + if nM == 0 or nK == 0 or nN == 0: + return out + + cblas_sgemm( + CblasRowMajor, + CblasTrans if trans1 else CblasNoTrans, + CblasTrans if trans2 else CblasNoTrans, + nM, + nN, + nK, + 1.0, + &A[0, 0], + A.shape[1], + &B[0, 0], + B.shape[1], + 0.0, + &C[0, 0], + C.shape[1] + ) + return out + + +cdef void sgemm(bint TransA, bint TransB, int M, int N, int K, + float alpha, const float* A, int lda, const float *B, + int ldb, float beta, float* C, int ldc) nogil: + cblas_sgemm( + CblasRowMajor, + CblasTrans if TransA else CblasNoTrans, + CblasTrans if TransB else CblasNoTrans, + M, + N, + K, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc + ) + + +cdef void saxpy(int N, float alpha, const float* X, int incX, + float *Y, int incY) nogil: + cblas_saxpy(N, alpha, X, incX, Y, incY) diff --git a/thinc/backends/apple_ops.pyx b/thinc/backends/apple_ops.pyx new file mode 100644 index 000000000..95a710c0d --- /dev/null +++ b/thinc/backends/apple_ops.pyx @@ -0,0 +1,39 @@ +from typing import Optional + +import numpy + +from ._accelerate import gemm + +from ._accelerate cimport saxpy, sgemm +from .cblas cimport CBlas, set_saxpy, set_sgemm + +from .. import registry +from ..types import Floats2d +from .numpy_ops import NumpyOps + + +@registry.ops("AppleOps") +class AppleOps(NumpyOps): + """Thinc Ops class that calls into Apple's native libraries for some + operations. 
Other operations fall back to numpy.""" + name = "apple" + xp = numpy + + def cblas(self) -> CBlas: + cdef CBlas cblas = CBlas() + set_saxpy(cblas, saxpy) + set_sgemm(cblas, sgemm) + return cblas + + def gemm( + self, + x: Floats2d, + y: Floats2d, + out: Optional[Floats2d] = None, + trans1: bool = False, + trans2: bool = False, + ) -> Floats2d: + """Perform General Matrix Multiplication (GeMM) and optionally store + the result in the specified output variable. + """ + return gemm(x, y, out=out, trans1=trans1, trans2=trans2) diff --git a/thinc/backends/cblas.pxd b/thinc/backends/cblas.pxd index 73cea1f2d..c608d8702 100644 --- a/thinc/backends/cblas.pxd +++ b/thinc/backends/cblas.pxd @@ -1,8 +1,11 @@ from libcpp.memory cimport shared_ptr ctypedef void (*sgemm_ptr)(bint transA, bint transB, int M, int N, int K, - float alpha, const float* A, int lda, const float *B, + float alpha, const float* A, int lda, const float* B, int ldb, float beta, float* C, int ldc) nogil +ctypedef void (*dgemm_ptr)(bint transA, bint transB, int M, int N, int K, + double alpha, const double* A, int lda, const double* B, + int ldb, double beta, double* C, int ldc) nogil ctypedef void (*saxpy_ptr)(int N, float alpha, const float* X, int incX, @@ -12,6 +15,8 @@ ctypedef void (*saxpy_ptr)(int N, float alpha, const float* X, int incX, ctypedef void (*daxpy_ptr)(int N, double alpha, const double* X, int incX, double *Y, int incY) nogil +ctypedef void (*sscal_ptr)(int N, float alpha, float* X, int incX) nogil +ctypedef void (*dscal_ptr)(int N, double alpha, double* X, int incX) nogil # Forward-declaration of the BlasFuncs struct. This struct must be opaque, so # that consumers of the CBlas class cannot become dependent on its size or @@ -32,6 +37,12 @@ cdef class CBlas: cdef daxpy_ptr daxpy(CBlas cblas) nogil cdef saxpy_ptr saxpy(CBlas cblas) nogil cdef sgemm_ptr sgemm(CBlas cblas) nogil +cdef dgemm_ptr dgemm(CBlas cblas) nogil +cdef sscal_ptr sscal(CBlas cblas) nogil +cdef dscal_ptr dscal(CBlas cblas) nogil cdef void set_daxpy(CBlas cblas, daxpy_ptr daxpy) nogil cdef void set_saxpy(CBlas cblas, saxpy_ptr saxpy) nogil cdef void set_sgemm(CBlas cblas, sgemm_ptr sgemm) nogil +cdef void set_dgemm(CBlas cblas, dgemm_ptr dgemm) nogil +cdef void set_sscal(CBlas cblas, sscal_ptr sscal) nogil +cdef void set_dscal(CBlas cblas, dscal_ptr dscal) nogil diff --git a/thinc/backends/cblas.pyx b/thinc/backends/cblas.pyx index e35169417..896b60481 100644 --- a/thinc/backends/cblas.pyx +++ b/thinc/backends/cblas.pyx @@ -4,10 +4,21 @@ from cython.operator cimport dereference as deref from libcpp.memory cimport make_shared +# Single- and double-precision wrappers for `blis.cy.scalv` +cdef void blis_sscal(int N, float alpha, float* X, int incX) nogil: + blis.cy.scalv(blis.cy.NO_CONJUGATE, N, alpha, X, incX) + +cdef void blis_dscal(int N, double alpha, double* X, int incX) nogil: + blis.cy.scalv(blis.cy.NO_CONJUGATE, N, alpha, X, incX) + + cdef struct BlasFuncs: daxpy_ptr daxpy saxpy_ptr saxpy sgemm_ptr sgemm + dgemm_ptr dgemm + sscal_ptr sscal + dscal_ptr dscal cdef class CBlas: @@ -20,6 +31,9 @@ cdef class CBlas: funcs.daxpy = blis.cy.daxpy funcs.saxpy = blis.cy.saxpy funcs.sgemm = blis.cy.sgemm + funcs.dgemm = blis.cy.dgemm + funcs.sscal = blis_sscal + funcs.dscal = blis_dscal self.ptr = make_shared[BlasFuncs](funcs) cdef daxpy_ptr daxpy(CBlas cblas) nogil: @@ -31,6 +45,15 @@ cdef saxpy_ptr saxpy(CBlas cblas) nogil: cdef sgemm_ptr sgemm(CBlas cblas) nogil: return deref(cblas.ptr).sgemm +cdef dgemm_ptr dgemm(CBlas cblas) nogil: + return 
deref(cblas.ptr).dgemm + +cdef sscal_ptr sscal(CBlas cblas) nogil: + return deref(cblas.ptr).sscal + +cdef dscal_ptr dscal(CBlas cblas) nogil: + return deref(cblas.ptr).dscal + cdef void set_daxpy(CBlas cblas, daxpy_ptr daxpy) nogil: deref(cblas.ptr).daxpy = daxpy @@ -39,3 +62,12 @@ cdef void set_saxpy(CBlas cblas, saxpy_ptr saxpy) nogil: cdef void set_sgemm(CBlas cblas, sgemm_ptr sgemm) nogil: deref(cblas.ptr).sgemm = sgemm + +cdef void set_dgemm(CBlas cblas, dgemm_ptr dgemm) nogil: + deref(cblas.ptr).dgemm = dgemm + +cdef void set_sscal(CBlas cblas, sscal_ptr sscal) nogil: + deref(cblas.ptr).sscal = sscal + +cdef void set_dscal(CBlas cblas, dscal_ptr dscal) nogil: + deref(cblas.ptr).dscal = dscal diff --git a/thinc/backends/cupy_ops.py b/thinc/backends/cupy_ops.py index 1e1e5b92b..472b6c542 100644 --- a/thinc/backends/cupy_ops.py +++ b/thinc/backends/cupy_ops.py @@ -1,7 +1,7 @@ import numpy from .. import registry -from ..compat import cupy, cupyx +from ..compat import cublas, cupy, cupyx from ..types import DeviceTypes from ..util import ( is_cupy_array, @@ -257,7 +257,7 @@ def clip_gradient(self, gradient, threshold): # implementation. def frobenius_norm(X): X_vec = X.reshape(-1) - return cupy.cublas.nrm2(X_vec) + return cublas.nrm2(X_vec) grad_norm = cupy.maximum(frobenius_norm(gradient), 1e-12) gradient *= cupy.minimum(threshold, grad_norm) / grad_norm diff --git a/thinc/backends/linalg.pxd b/thinc/backends/linalg.pxd deleted file mode 100644 index 37fb9ea2b..000000000 --- a/thinc/backends/linalg.pxd +++ /dev/null @@ -1,275 +0,0 @@ -# cython: infer_types=True -# cython: cdivision=True - -cimport cython -from cymem.cymem cimport Pool -from libc.stdint cimport int32_t -from libc.string cimport memcpy, memset - -ctypedef float weight_t - -DEF USE_BLAS = False -DEF EPS = 1e-5 - - -IF USE_BLAS: - cimport blis.cy - -cdef extern from "math.h" nogil: - weight_t exp(weight_t x) - weight_t sqrt(weight_t x) - - -cdef class Matrix: - cdef readonly Pool mem - cdef weight_t* data - cdef readonly int32_t nr_row - cdef readonly int32_t nr_col - - -cdef class Vec: - @staticmethod - cdef inline int arg_max(const weight_t* scores, const int n_classes) nogil: - if n_classes == 2: - return 0 if scores[0] > scores[1] else 1 - cdef int i - cdef int best = 0 - cdef weight_t mode = scores[0] - for i in range(1, n_classes): - if scores[i] > mode: - mode = scores[i] - best = i - return best - - @staticmethod - cdef inline weight_t max(const weight_t* x, int32_t nr) nogil: - if nr == 0: - return 0 - cdef int i - cdef weight_t mode = x[0] - for i in range(1, nr): - if x[i] > mode: - mode = x[i] - return mode - - @staticmethod - cdef inline weight_t sum(const weight_t* vec, int32_t nr) nogil: - cdef int i - cdef weight_t total = 0 - for i in range(nr): - total += vec[i] - return total - - @staticmethod - cdef inline weight_t norm(const weight_t* vec, int32_t nr) nogil: - cdef weight_t total = 0 - for i in range(nr): - total += vec[i] ** 2 - return sqrt(total) - - @staticmethod - cdef inline void add(weight_t* output, const weight_t* x, - weight_t inc, int32_t nr) nogil: - memcpy(output, x, sizeof(output[0]) * nr) - Vec.add_i(output, inc, nr) - - @staticmethod - cdef inline void add_i(weight_t* vec, weight_t inc, int32_t nr) nogil: - cdef int i - for i in range(nr): - vec[i] += inc - - @staticmethod - cdef inline void mul(weight_t* output, const weight_t* vec, weight_t scal, - int32_t nr) nogil: - memcpy(output, vec, sizeof(output[0]) * nr) - Vec.mul_i(output, scal, nr) - - @staticmethod - cdef inline void 
mul_i(weight_t* vec, weight_t scal, int32_t nr) nogil: - cdef int i - IF USE_BLAS: - blis.cy.scalv(BLIS_NO_CONJUGATE, nr, scal, vec, 1) - ELSE: - for i in range(nr): - vec[i] *= scal - - @staticmethod - cdef inline void pow(weight_t* output, const weight_t* vec, weight_t scal, - int32_t nr) nogil: - memcpy(output, vec, sizeof(output[0]) * nr) - Vec.pow_i(output, scal, nr) - - @staticmethod - cdef inline void pow_i(weight_t* vec, const weight_t scal, int32_t nr) nogil: - cdef int i - for i in range(nr): - vec[i] **= scal - - @staticmethod - @cython.cdivision(True) - cdef inline void div(weight_t* output, const weight_t* vec, weight_t scal, - int32_t nr) nogil: - memcpy(output, vec, sizeof(output[0]) * nr) - Vec.div_i(output, scal, nr) - - @staticmethod - @cython.cdivision(True) - cdef inline void div_i(weight_t* vec, const weight_t scal, int32_t nr) nogil: - cdef int i - for i in range(nr): - vec[i] /= scal - - @staticmethod - cdef inline void exp(weight_t* output, const weight_t* vec, int32_t nr) nogil: - memcpy(output, vec, sizeof(output[0]) * nr) - Vec.exp_i(output, nr) - - @staticmethod - cdef inline void exp_i(weight_t* vec, int32_t nr) nogil: - cdef int i - for i in range(nr): - vec[i] = exp(vec[i]) - - @staticmethod - cdef inline void reciprocal_i(weight_t* vec, int32_t nr) nogil: - cdef int i - for i in range(nr): - vec[i] = 1.0 / vec[i] - - @staticmethod - cdef inline weight_t mean(const weight_t* X, int32_t nr_dim) nogil: - cdef weight_t mean = 0. - for x in X[:nr_dim]: - mean += x - return mean / nr_dim - - @staticmethod - cdef inline weight_t variance(const weight_t* X, int32_t nr_dim) nogil: - # See https://www.johndcook.com/blog/standard_deviation/ - cdef double m = X[0] - cdef double v = 0. - for i in range(1, nr_dim): - diff = X[i]-m - m += diff / (i+1) - v += diff * (X[i] - m) - return v / nr_dim - - -cdef class VecVec: - @staticmethod - cdef inline void add(weight_t* output, - const weight_t* x, - const weight_t* y, - weight_t scale, - int32_t nr) nogil: - memcpy(output, x, sizeof(output[0]) * nr) - VecVec.add_i(output, y, scale, nr) - - @staticmethod - cdef inline void add_i(weight_t* x, - const weight_t* y, - weight_t scale, - int32_t nr) nogil: - cdef int i - IF USE_BLAS: - blis.cy.axpyv(BLIS_NO_CONJUGATE, nr, scale, y, 1, x, 1) - ELSE: - for i in range(nr): - x[i] += y[i] * scale - - @staticmethod - cdef inline void batch_add_i(weight_t* x, - const weight_t* y, - weight_t scale, - int32_t nr, int32_t nr_batch) nogil: - # For fixed x, matrix of y - cdef int i, _ - for _ in range(nr_batch): - VecVec.add_i(x, - y, scale, nr) - y += nr - - @staticmethod - cdef inline void add_pow(weight_t* output, - const weight_t* x, const weight_t* y, weight_t power, int32_t nr) nogil: - memcpy(output, x, sizeof(output[0]) * nr) - VecVec.add_pow_i(output, y, power, nr) - - - @staticmethod - cdef inline void add_pow_i(weight_t* x, - const weight_t* y, weight_t power, int32_t nr) nogil: - cdef int i - for i in range(nr): - x[i] += y[i] ** power - - @staticmethod - cdef inline void mul(weight_t* output, - const weight_t* x, const weight_t* y, int32_t nr) nogil: - memcpy(output, x, sizeof(output[0]) * nr) - VecVec.mul_i(output, y, nr) - - @staticmethod - cdef inline void mul_i(weight_t* x, - const weight_t* y, int32_t nr) nogil: - cdef int i - for i in range(nr): - x[i] *= y[i] - - @staticmethod - cdef inline weight_t dot( - const weight_t* x, const weight_t* y, int32_t nr) nogil: - cdef int i - cdef weight_t total = 0 - for i in range(nr): - total += x[i] * y[i] - return total - - 
@staticmethod - cdef inline int arg_max_if_true( - const weight_t* scores, const int* is_valid, const int n_classes) nogil: - cdef int i - cdef int best = -1 - for i in range(n_classes): - if is_valid[i] and (best == -1 or scores[i] > scores[best]): - best = i - return best - - @staticmethod - cdef inline int arg_max_if_zero( - const weight_t* scores, const weight_t* costs, const int n_classes) nogil: - cdef int i - cdef int best = -1 - for i in range(n_classes): - if costs[i] == 0 and (best == -1 or scores[i] > scores[best]): - best = i - return best - - -cdef class Mat: - @staticmethod - cdef inline void mean_row(weight_t* Ex, - const weight_t* mat, int32_t nr_row, int32_t nr_col) nogil: - memset(Ex, 0, sizeof(Ex[0]) * nr_col) - for i in range(nr_row): - VecVec.add_i(Ex, &mat[i * nr_col], 1.0, nr_col) - Vec.mul_i(Ex, 1.0 / nr_row, nr_col) - - @staticmethod - cdef inline void var_row(weight_t* Vx, - const weight_t* mat, const weight_t* Ex, - int32_t nr_row, int32_t nr_col, weight_t eps) nogil: - # From https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance - if nr_row == 0 or nr_col == 0: - return - cdef weight_t sum_, sum2 - for i in range(nr_col): - sum_ = 0.0 - sum2 = 0.0 - for j in range(nr_row): - x = mat[j * nr_col + i] - sum2 += (x - Ex[i]) ** 2 - sum_ += x - Ex[i] - Vx[i] = (sum2 - sum_**2 / nr_row) / nr_row - Vx[i] += eps diff --git a/thinc/backends/linalg.pyx b/thinc/backends/linalg.pyx deleted file mode 100644 index 64a360731..000000000 --- a/thinc/backends/linalg.pyx +++ /dev/null @@ -1,5 +0,0 @@ -# cython: profile=False -try: - import blis.py -except ImportError: - pass diff --git a/thinc/backends/mps_ops.py b/thinc/backends/mps_ops.py index c6ba71f11..fb242f0f1 100644 --- a/thinc/backends/mps_ops.py +++ b/thinc/backends/mps_ops.py @@ -3,6 +3,7 @@ import numpy from .. import registry +from ..compat import has_apple_ops from .numpy_ops import NumpyOps from .ops import Ops @@ -12,11 +13,11 @@ # during type checking. _Ops = Ops else: - try: - from thinc_apple_ops import AppleOps + if has_apple_ops: + from .apple_ops import AppleOps _Ops = AppleOps - except ImportError: + else: _Ops = NumpyOps diff --git a/thinc/backends/numpy_ops.pyx b/thinc/backends/numpy_ops.pyx index 78eee6ada..bb38bae3c 100644 --- a/thinc/backends/numpy_ops.pyx +++ b/thinc/backends/numpy_ops.pyx @@ -5,7 +5,6 @@ from typing import Optional import numpy -cimport blis.cy cimport cython cimport numpy as np from cymem.cymem cimport Pool @@ -20,19 +19,10 @@ from .. 
import registry from ..types import ArrayXd, DeviceTypes, DTypes, Shape from ..util import copy_array, get_array_module -from .cblas cimport CBlas, daxpy, saxpy -from .linalg cimport Vec, VecVec +from .cblas cimport CBlas, daxpy, dgemm, saxpy, sgemm, sscal -from .ops import Ops - -try: - import blis.py - has_blis = True -except ImportError: - has_blis = False - - -ctypedef float weight_t +from ..compat import has_blis +from .ops import Ops, _split_weights, _transpose_weights, _untranspose_unsplit_weights cdef extern from "math.h": @@ -95,11 +85,45 @@ class NumpyOps(Ops): raise ValueError(f"Provided 'y' array should be 2-dimensional, but found {y.ndim} dimension(s).") if not self.use_blis: # delegate to base Ops return super().gemm(x, y, out=out, trans1=trans1, trans2=trans2) + x = self.as_contig(x) y = self.as_contig(y) + + cdef int nM = x.shape[0] if not trans1 else x.shape[1] + cdef int nK = x.shape[1] if not trans1 else x.shape[0] + cdef int nK_b = y.shape[0] if not trans2 else y.shape[1] + cdef int nN = y.shape[1] if not trans2 else y.shape[0] + if nK != nK_b: + msg = "Shape mismatch for blis.gemm: (%d, %d), (%d, %d)" + raise ValueError(msg % (nM, nK, nK_b, nN)) + if out is not None: out = self.as_contig(out) - return blis.py.gemm(x, y, out=out, trans1=trans1, trans2=trans2, beta=0.) + else: + # Can be uninitialized as 'beta' is zero. + out = numpy.empty((nM, nN), dtype=x.dtype) + + cblas = self.cblas() + if x.dtype == "float32" and y.dtype == "float32" and out.dtype == "float32": + sgemm(cblas)(trans1, trans2, + nM, nN, nK, + 1.0, + (x.data), x.shape[1], + (y.data), y.shape[1], + 0.0, + (out.data), out.shape[1]) + elif x.dtype == "float64" and y.dtype == "float64" and out.dtype == "float64": + dgemm(cblas)(trans1, trans2, + nM, nN, nK, + 1.0, + (x.data), x.shape[1], + (y.data), y.shape[1], + 0.0, + (out.data), out.shape[1]) + else: + raise ValueError(f"unsupported or mismatching array data types; got '{x.dtype}', '{y.dtype}', '{out.dtype}'") + + return out def relu(self, np.ndarray X, inplace=False): cdef np.ndarray Y @@ -119,12 +143,12 @@ class NumpyOps(Ops): _check_compatible_shape(dY, Y) cdef size_t size = Y.size - cdef weight_t* dX_ptr - cdef const weight_t* Y_ptr = Y.data + cdef float* dX_ptr + cdef const float* Y_ptr = Y.data cdef np.ndarray dX if dY.dtype == "float32" and Y.dtype == "float32": dX = _inplace_or_copy(dY, inplace) - dX_ptr = dX.data + dX_ptr = dX.data for i in range(size): if Y_ptr[i] <= 0: dX_ptr[i] = 0. 
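
The rewritten `NumpyOps.gemm` above now dispatches on dtype to the `sgemm` or `dgemm` pointer held by the `CBlas` struct, and raises a `ValueError` for mixed or unsupported dtypes instead of upcasting. A minimal sketch of the Python-level behaviour, assuming thinc is installed with the optional `blis` package so `use_blis` is enabled (otherwise `gemm` falls back to the base `Ops` implementation):

```python
import numpy
from thinc.api import NumpyOps

ops = NumpyOps()
a = numpy.random.rand(4, 3).astype("float32")
b = numpy.random.rand(3, 5).astype("float32")

c32 = ops.gemm(a, b)                                      # float32 inputs -> sgemm pointer
c64 = ops.gemm(a.astype("float64"), b.astype("float64"))  # float64 inputs -> dgemm pointer
numpy.testing.assert_allclose(c32, a @ b, rtol=1e-5)
```
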
@@ -142,7 +166,7 @@ class NumpyOps(Ops): ): assert H0.shape[0] == C0.shape[0] assert H0.shape[1] == C0.shape[1] - Y, fwd_state = lstm_forward_training(params, H0, C0, X, size_at_t) + Y, fwd_state = lstm_forward_training(self.cblas(), params, H0, C0, X, size_at_t) return Y, fwd_state def lstm_forward_inference( @@ -153,13 +177,13 @@ class NumpyOps(Ops): np.ndarray X, np.ndarray size_at_t ): - Y, _ = lstm_forward_training(params, H0, C0, X, size_at_t) + Y, _ = lstm_forward_training(self.cblas(), params, H0, C0, X, size_at_t) return Y def backprop_lstm( self, np.ndarray dY, np.ndarray lengths, np.ndarray params, fwd_state ): - dX, d_params = backprop_lstm(dY, lengths, params, fwd_state) + dX, d_params = backprop_lstm(self.cblas(), dY, lengths, params, fwd_state) return dX, d_params def maxout(self, reals3d_ft X): @@ -486,7 +510,7 @@ class NumpyOps(Ops): and values.ndim == 2 \ and values.shape[0] == indices.shape[0] \ and values.shape[1] == table.shape[1]: - cpu_scatter_add(table.data, + cpu_scatter_add(self.cblas(), table.data, indices.data, values.data, indices.shape[0], table.shape[1]) else: @@ -502,10 +526,11 @@ class NumpyOps(Ops): _check_compatible_shape(weights, mom1) _check_compatible_shape(weights, mom2) - _adam_momentum(gradient.data, mom1.data, mom2.data, + cdef CBlas cblas = self.cblas() + _adam_momentum(cblas, gradient.data, mom1.data, mom2.data, weights.shape[0], beta1, beta2, eps, learn_rate) - VecVec.add_i(weights.data, - gradient.data, -learn_rate, weights.shape[0]) + saxpy(cblas)(weights.shape[0], -learn_rate, gradient.data, 1, weights.data, 1) + memset(gradient.data, 0, gradient.size * sizeof(float)) return weights, gradient, mom1, mom2 @@ -542,21 +567,6 @@ def check_seq2col_lengths(ops, lengths, B): return lengths -def cpu_clip_gradient(weight_t[::1] gradient, weight_t threshold): - grad_norm = Vec.norm(&gradient[0], gradient.shape[0]) - if grad_norm >= threshold: - Vec.mul_i(&gradient[0], threshold / grad_norm, gradient.shape[0]) - - -def add_gradient_noise(float[::1] gradient, weight_t noise_level, - weight_t timestep): - cdef weight_t variance = noise_level / ((1 + timestep) ** 0.55) - if variance >= 0.000001: - gradient += numpy.asarray( - numpy.random.normal(scale=variance, loc=0., size=len(gradient)), - dtype='float32') - - cdef void cpu_position_encode(float* output, float period, int N, int D) nogil: cdef float pos, d cdef int j @@ -575,39 +585,38 @@ cdef void cpu_position_encode(float* output, float period, int N, int D) nogil: output += D -cdef void cpu_scatter_add(float* dest, +cdef void cpu_scatter_add(CBlas cblas, float* dest, const int* indices, const float* src, int nr_id, int nr_col) nogil: cdef int i for i in range(nr_id): id_ = indices[i] if id_ >= 0: - VecVec.add_i(&dest[id_*nr_col], - &src[i*nr_col], 1., nr_col) + saxpy(cblas)(nr_col, 1., &src[i*nr_col], 1, &dest[id_*nr_col], 1) @cython.cdivision(True) -cdef void _adam_momentum(weight_t* gradient, weight_t* mom1, weight_t* mom2, - int nr_weight, weight_t beta1, weight_t beta2, weight_t eps, - weight_t learn_rate) nogil: +cdef void _adam_momentum(CBlas cblas, float* gradient, float* mom1, float* mom2, + int nr_weight, float beta1, float beta2, float eps, + float learn_rate) nogil: # Calculate Adam on CPU, fused. 
# Assumes the learning rate adjustment is calculated by the caller; # a_t = learn_rate * sqrt(1-beta2**timestep) / (1-beta1**timestep) - cdef weight_t one_minus_beta1 = 1-beta1 - cdef weight_t one_minus_beta2 = 1-beta2 - cdef weight_t m1, m2, g + cdef float one_minus_beta1 = 1-beta1 + cdef float one_minus_beta2 = 1-beta2 + cdef float m1, m2, g cdef int i # Blockwise implementation is a bit faster. Adam is slooow :( - cdef weight_t[64] buff + cdef float[64] buff cdef int steps = nr_weight // 64 if steps * 64 < nr_weight: steps += 1 idx = 0 for i in range(steps): step_size = min(64, nr_weight-idx) - Vec.mul_i(mom1, beta1, step_size) - VecVec.add_i(mom1, gradient, one_minus_beta1, step_size) - Vec.mul_i(mom2, beta2, step_size) + sscal(cblas)(step_size, beta1, mom1, 1) + saxpy(cblas)(step_size, one_minus_beta1, gradient, 1, mom1, 1) + sscal(cblas)(step_size, beta2, mom2, 1) for j in range(step_size): mom2[j] += one_minus_beta2 * gradient[j] ** 2 for j in range(step_size): @@ -624,19 +633,7 @@ cdef void _adam_momentum(weight_t* gradient, weight_t* mom1, weight_t* mom2, idx += step_size -@cython.cdivision(True) -cdef void cpu_update_averages(weight_t* ema, - const weight_t* weights, int nr_weight, weight_t t, weight_t max_decay) nogil: - cdef weight_t decay = (1.0 + t) / (10.0 + t) - if decay > max_decay: - decay = max_decay - cdef weight_t one_minus_decay = 1-decay - cdef int i - for i in range(nr_weight): # num_threads=4, schedule='static'): - ema[i] -= one_minus_decay * (ema[i] - weights[i]) - - -def lstm_forward_training( +def lstm_forward_training(CBlas cblas, np.ndarray params, np.ndarray c_init, np.ndarray h_init, np.ndarray X, np.ndarray lengths ): @@ -678,6 +675,7 @@ def lstm_forward_training( Cid = C[i, d] Gid = G[i, d] _lstm_forward_training( + cblas, d, N, nO, nI, nT, Gid, Yid.data, @@ -698,6 +696,7 @@ def lstm_forward_training( cdef int _lstm_forward_training( + CBlas cblas, int d, int N, int nO, int nI, int nT, np.ndarray G, float* Y, @@ -711,13 +710,13 @@ cdef int _lstm_forward_training( float* Ct2, ) except -1: cdef double one = 1.0 - blis.cy.gemm(blis.cy.NO_TRANSPOSE, blis.cy.TRANSPOSE, + sgemm(cblas)(False, True, N, nO*4, nI, one, - X, nI, 1, - Wx, nI, 1, + X, nI, + Wx, nI, one, - G.data, nO*4, 1 + G.data, nO*4 ) cdef int t, batch_size cdef int seq_i = 0 if d == 0 else N @@ -735,13 +734,13 @@ cdef int _lstm_forward_training( Gt3_ = G[seq_i : seq_i+batch_size] Gt3 = Gt3_.data # Now do the actual calculation - blis.cy.gemm(blis.cy.NO_TRANSPOSE, blis.cy.TRANSPOSE, + sgemm(cblas)(False, True, batch_size, nO*4, nO, one, - Yt2, nO, 1, - Wh, nO, 1, + Yt2, nO, + Wh, nO, one, - Gt3, nO*4, 1 + Gt3, nO*4 ) # This is super weird: if we remove this add, it gets slower? I guess # it does cache prefetching or something? 
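
A few hunks up in this file, the fused CPU Adam kernel now scales and accumulates the moment buffers through the CBlas `sscal`/`saxpy` pointers rather than the removed `Vec`/`VecVec` helpers, and the weight update plus gradient reset happen inside the same call. A small Python-level sketch of that behaviour (illustrative shapes and hyperparameters only):

```python
import numpy
from thinc.api import NumpyOps

ops = NumpyOps()
weights = numpy.zeros(8, dtype="float32")
grads = numpy.ones(8, dtype="float32")
mom1 = numpy.zeros(8, dtype="float32")
mom2 = numpy.zeros(8, dtype="float32")

weights, grads, mom1, mom2 = ops.adam(
    weights, grads, mom1, mom2, beta1=0.9, beta2=0.999, eps=1e-8, learn_rate=0.001
)
assert not grads.any()  # the gradient buffer is zeroed in place after the step
```
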
@@ -765,7 +764,7 @@ cdef int _lstm_forward_training( memcpy(Ct2, Ct3, sizeof(Ct3[0]) * batch_size * nO) -def backprop_lstm(np.ndarray dY, np.ndarray lengths, np.ndarray params, fwd_state): +def backprop_lstm(CBlas cblas, np.ndarray dY, np.ndarray lengths, np.ndarray params, fwd_state): xp = numpy cdef np.ndarray Y cdef np.ndarray G @@ -842,7 +841,7 @@ def backprop_lstm(np.ndarray dY, np.ndarray lengths, np.ndarray params, fwd_stat assert (dYid.shape[0], dYid.shape[1]) == (N, nO) assert (dC.shape[0], dC.shape[1]) == (N, nO) assert (dG.shape[0], dG.shape[1]) == (N, nO*4) - _lstm_backward_training(d, N, nO, dX.shape[1], nT, + _lstm_backward_training(cblas, d, N, nO, dX.shape[1], nT, dX.data, dYid.data, dC.data, @@ -867,18 +866,8 @@ def backprop_lstm(np.ndarray dY, np.ndarray lengths, np.ndarray params, fwd_stat return dX, numpy.concatenate(grad_parts) -def _split_directions(X, dirs): - if dirs == 1: - return [X] - else: - X_ = X.reshape((X.shape[0], -1, dirs)) - Xs = [] - for d in range(dirs): - Xs.append(numpy.ascontiguousarray(X_[:, d])) - return Xs - - cdef int _lstm_backward_training( + CBlas cblas, int d, int N, int nO, int nI, int nT, float* dX, float* dY, @@ -923,36 +912,36 @@ cdef int _lstm_backward_training( ) # Backprop hidden-to-hidden w.r.t. hidden. # dYt2 += dGt3 @ Wh - blis.cy.gemm(blis.cy.NO_TRANSPOSE, blis.cy.NO_TRANSPOSE, + sgemm(cblas)(False, False, batch_size, nO, nO*4, one, - dGt3, nO*4, 1, - Wh, nO, 1, + dGt3, nO*4, + Wh, nO, one, - dYt2, nO, 1 + dYt2, nO ) seq_t3 = seq_t2 size_t3 = size_t2 # Backprop input-to-hidden w.r.t. weights. # dWx += dG @ X - blis.cy.gemm(blis.cy.TRANSPOSE, blis.cy.NO_TRANSPOSE, + sgemm(cblas)(True, False, nO*4, nI, N, one, - dG, nO*4, 1, - X, nI, 1, + dG, nO*4, + X, nI, one, - dWx, nI, 1 + dWx, nI ) # Backprop hidden-to-hidden w.r.t weights. # dWh += dG @ Y - blis.cy.gemm(blis.cy.TRANSPOSE, blis.cy.NO_TRANSPOSE, + sgemm(cblas)(True, False, nO*4, nO, N, one, - dG, nO*4, 1, - Y, nO, 1, + dG, nO*4, + Y, nO, one, - dWh, nO, 1 + dWh, nO ) # Backprop bias for i in range(N): @@ -960,64 +949,16 @@ cdef int _lstm_backward_training( d_bias[j] += dG[i*nO*4+j] # Backprop input-to-hidden w.r.t. input - blis.cy.gemm(blis.cy.NO_TRANSPOSE, blis.cy.NO_TRANSPOSE, + sgemm(cblas)(False, False, N, nI, nO*4, one, - dG, nO*4, 1, - Wx, nI, 1, + dG, nO*4, + Wx, nI, one, - dX, nI, 1 + dX, nI ) -def _split_weights(np.ndarray params, int i, int nO, int nI, int params_i): - Wx_size = 4 * nO * nI - bx_size = 4 * nO - Wh_size = 4 * nO * nO - bh_size = 4 * nO - Wx = params[params_i : params_i + Wx_size].reshape((4 * nO, nI)) - params_i += Wx_size - bx = params[params_i : params_i + bx_size].reshape((4 * nO,)) - params_i += bx_size - Wh = params[params_i : params_i + Wh_size].reshape((4 * nO, nO)) - params_i += Wh_size - bh = params[params_i : params_i + bh_size].reshape((4 * nO,)) - params_i += bh_size - return ((Wx, bx), (Wh, bh)), params_i - - -def _transpose_weights(params): - # Transpose the parameters so that the gates are the last dimension. This - # makes it easier to fuse. 
- (Wx, bx), (Wh, bh) = params - Wx = Wx.reshape((4, -1, Wx.shape[-1])) - Wx = Wx.transpose((1, 0, 2)).reshape((-1, Wx.shape[-1])) - bx = bx.reshape((4, -1)).transpose((1, 0)).reshape((-1,)) - Wh = Wh.reshape((4, -1, Wh.shape[-1])) - Wh = Wh.transpose((1, 0, 2)).reshape((-1, Wh.shape[-1])) - bh = bh.reshape((4, -1)).transpose((1, 0)).reshape((-1,)) - ascontig = numpy.ascontiguousarray - Wx = ascontig(Wx) - Wh = ascontig(Wh) - bias = ascontig(bx) + bh - return Wx, Wh, bias - - -def _untranspose_unsplit_weights(params): - Wx, Wh, bias = params - nO = Wh.shape[1] - nI = Wx.shape[1] - Wx = Wx.reshape((-1, 4, nI)).transpose((1, 0, 2)).reshape((-1, nI)) - Wh = Wh.reshape((-1, 4, nO)).transpose((1, 0, 2)).reshape((-1, nO)) - bias = bias.reshape((-1, 4)).transpose((1, 0)).reshape((-1,)) - zeros = numpy.zeros(bias.shape, dtype="f") - return numpy.concatenate((Wx.ravel(), bias, Wh.ravel(), zeros)) - - -cdef inline float sigmoid(float X) nogil: - return 1./(1. + expf(-X)) - - cdef inline float dsigmoid(float y) nogil: return y*(1-y) diff --git a/thinc/compat.py b/thinc/compat.py index 5d600796a..2ec91de48 100644 --- a/thinc/compat.py +++ b/thinc/compat.py @@ -1,12 +1,15 @@ +import platform import warnings from packaging.version import Version try: # pragma: no cover import cupy + import cupy.cublas import cupyx has_cupy = True + cublas = cupy.cublas cupy_version = Version(cupy.__version__) try: cupy.cuda.runtime.getDeviceCount() @@ -20,6 +23,7 @@ else: cupy_from_dlpack = cupy.fromDlpack except (ImportError, AttributeError): + cublas = None cupy = None cupyx = None cupy_version = Version("0.0.0") @@ -107,6 +111,18 @@ def enable_mxnet(): has_os_signpost = False +try: # pragma: no cover + import blis + + has_blis = True +except ImportError: + blis = None + has_blis = False + + +# AppleOps is available unconditionally on macOS. 
+has_apple_ops = platform.system() == "Darwin" + has_gpu = has_cupy_gpu or has_torch_mps_gpu __all__ = [ diff --git a/thinc/extra/__init__.pxd b/thinc/extra/__init__.pxd deleted file mode 100644 index e69de29bb..000000000 diff --git a/thinc/extra/search.pxd b/thinc/extra/search.pxd deleted file mode 100644 index a27ba0525..000000000 --- a/thinc/extra/search.pxd +++ /dev/null @@ -1,90 +0,0 @@ -from cymem.cymem cimport Pool -from libc.stdint cimport uint32_t, uint64_t -from libcpp.pair cimport pair -from libcpp.queue cimport priority_queue -from libcpp.vector cimport vector - -ctypedef uint64_t hash_t -ctypedef uint64_t class_t -ctypedef float weight_t - - -ctypedef pair[weight_t, size_t] Entry -ctypedef priority_queue[Entry] Queue - - -ctypedef int (*trans_func_t)(void* dest, void* src, class_t clas, void* x) except -1 - -ctypedef void* (*init_func_t)(Pool mem, int n, void* extra_args) except NULL - -ctypedef int (*del_func_t)(Pool mem, void* state, void* extra_args) except -1 - -ctypedef int (*finish_func_t)(void* state, void* extra_args) except -1 - -ctypedef hash_t (*hash_func_t)(void* state, void* x) except 0 - - -cdef struct _State: - void* content - class_t* hist - weight_t score - weight_t loss - int i - int t - bint is_done - - -cdef class Beam: - cdef Pool mem - cdef class_t nr_class - cdef class_t width - cdef class_t size - cdef public weight_t min_density - cdef int t - cdef readonly bint is_done - cdef list histories - cdef list _parent_histories - cdef weight_t** scores - cdef int** is_valid - cdef weight_t** costs - cdef _State* _parents - cdef _State* _states - cdef del_func_t del_func - - cdef int _fill(self, Queue* q, weight_t** scores, int** is_valid) except -1 - - cdef inline void* at(self, int i) nogil: - return self._states[i].content - - cdef int initialize(self, init_func_t init_func, del_func_t del_func, int n, void* extra_args) except -1 - cdef int advance(self, trans_func_t transition_func, hash_func_t hash_func, - void* extra_args) except -1 - cdef int check_done(self, finish_func_t finish_func, void* extra_args) except -1 - - - cdef inline void set_cell(self, int i, int j, weight_t score, int is_valid, weight_t cost) nogil: - self.scores[i][j] = score - self.is_valid[i][j] = is_valid - self.costs[i][j] = cost - - cdef int set_row(self, int i, const weight_t* scores, const int* is_valid, - const weight_t* costs) except -1 - cdef int set_table(self, weight_t** scores, int** is_valid, weight_t** costs) except -1 - - -cdef class MaxViolation: - cdef Pool mem - cdef weight_t cost - cdef weight_t delta - cdef readonly weight_t p_score - cdef readonly weight_t g_score - cdef readonly double Z - cdef readonly double gZ - cdef class_t n - cdef readonly list p_hist - cdef readonly list g_hist - cdef readonly list p_probs - cdef readonly list g_probs - - cpdef int check(self, Beam pred, Beam gold) except -1 - cpdef int check_crf(self, Beam pred, Beam gold) except -1 diff --git a/thinc/extra/search.pyx b/thinc/extra/search.pyx deleted file mode 100644 index 651e6ff04..000000000 --- a/thinc/extra/search.pyx +++ /dev/null @@ -1,303 +0,0 @@ -# cython: experimental_cpp_class_def=True, cdivision=True, infer_types=True -cimport cython -from libc.math cimport exp, log -from libc.string cimport memcpy, memset - -import math - -from cymem.cymem cimport Pool -from preshed.maps cimport PreshMap - - -cdef class Beam: - def __init__(self, class_t nr_class, class_t width, weight_t min_density=0.0): - assert nr_class != 0 - assert width != 0 - self.nr_class = nr_class - self.width = width 
- self.min_density = min_density - self.size = 1 - self.t = 0 - self.mem = Pool() - self._parents = <_State*>self.mem.alloc(self.width, sizeof(_State)) - self._states = <_State*>self.mem.alloc(self.width, sizeof(_State)) - cdef int i - self.histories = [[] for i in range(self.width)] - self._parent_histories = [[] for i in range(self.width)] - - self.scores = self.mem.alloc(self.width, sizeof(weight_t*)) - self.is_valid = self.mem.alloc(self.width, sizeof(weight_t*)) - self.costs = self.mem.alloc(self.width, sizeof(weight_t*)) - for i in range(self.width): - self.scores[i] = self.mem.alloc(self.nr_class, sizeof(weight_t)) - self.is_valid[i] = self.mem.alloc(self.nr_class, sizeof(int)) - self.costs[i] = self.mem.alloc(self.nr_class, sizeof(weight_t)) - - def __len__(self): - return self.size - - property score: - def __get__(self): - return self._states[0].score - - property min_score: - def __get__(self): - return self._states[self.size-1].score - - property loss: - def __get__(self): - return self._states[0].loss - - property probs: - def __get__(self): - return _softmax([self._states[i].score for i in range(self.size)]) - - property scores: - def __get__(self): - return [self._states[i].score for i in range(self.size)] - - property histories: - def __get__(self): - return self.histories - - cdef int set_row(self, int i, const weight_t* scores, const int* is_valid, - const weight_t* costs) except -1: - cdef int j - for j in range(self.nr_class): - self.scores[i][j] = scores[j] - self.is_valid[i][j] = is_valid[j] - self.costs[i][j] = costs[j] - - cdef int set_table(self, weight_t** scores, int** is_valid, weight_t** costs) except -1: - cdef int i, j - for i in range(self.width): - memcpy(self.scores[i], scores[i], sizeof(weight_t) * self.nr_class) - memcpy(self.is_valid[i], is_valid[i], sizeof(bint) * self.nr_class) - memcpy(self.costs[i], costs[i], sizeof(int) * self.nr_class) - - cdef int initialize(self, init_func_t init_func, del_func_t del_func, int n, void* extra_args) except -1: - for i in range(self.width): - self._states[i].content = init_func(self.mem, n, extra_args) - self._parents[i].content = init_func(self.mem, n, extra_args) - self.del_func = del_func - - def __dealloc__(self): - for i in range(self.width): - self.del_func(self.mem, self._states[i].content, NULL) - self.del_func(self.mem, self._parents[i].content, NULL) - - @cython.cdivision(True) - cdef int advance(self, trans_func_t transition_func, hash_func_t hash_func, - void* extra_args) except -1: - cdef weight_t** scores = self.scores - cdef int** is_valid = self.is_valid - cdef weight_t** costs = self.costs - - cdef Queue* q = new Queue() - self._fill(q, scores, is_valid) - # For a beam of width k, we only ever need 2k state objects. How? - # Each transition takes a parent and a class and produces a new state. - # So, we don't need the whole history --- just the parent. So at - # each step, we take a parent, and apply one or more extensions to - # it. 
- self._parents, self._states = self._states, self._parents - self._parent_histories, self.histories = self.histories, self._parent_histories - cdef weight_t score - cdef int p_i - cdef int i = 0 - cdef class_t clas - cdef _State* parent - cdef _State* state - cdef hash_t key - cdef PreshMap seen_states = PreshMap(self.width) - cdef uint64_t is_seen - cdef uint64_t one = 1 - while i < self.width and not q.empty(): - data = q.top() - p_i = data.second / self.nr_class - clas = data.second % self.nr_class - score = data.first - q.pop() - parent = &self._parents[p_i] - # Indicates terminal state reached; i.e. state is done - if parent.is_done: - # Now parent will not be changed, so we don't have to copy. - # Once finished, should also be unbranching. - self._states[i], parent[0] = parent[0], self._states[i] - parent.i = self._states[i].i - parent.t = self._states[i].t - parent.is_done = self._states[i].t - self._states[i].score = score - self.histories[i] = list(self._parent_histories[p_i]) - i += 1 - else: - state = &self._states[i] - # The supplied transition function should adjust the destination - # state to be the result of applying the class to the source state - transition_func(state.content, parent.content, clas, extra_args) - key = hash_func(state.content, extra_args) if hash_func is not NULL else 0 - is_seen = seen_states.get(key) - if key == 0 or key == 1 or not is_seen: - if key != 0 and key != 1: - seen_states.set(key, one) - state.score = score - state.loss = parent.loss + costs[p_i][clas] - self.histories[i] = list(self._parent_histories[p_i]) - self.histories[i].append(clas) - i += 1 - del q - self.size = i - assert self.size >= 1 - for i in range(self.width): - memset(self.scores[i], 0, sizeof(weight_t) * self.nr_class) - memset(self.costs[i], 0, sizeof(weight_t) * self.nr_class) - memset(self.is_valid[i], 0, sizeof(int) * self.nr_class) - self.t += 1 - - cdef int check_done(self, finish_func_t finish_func, void* extra_args) except -1: - cdef int i - for i in range(self.size): - if not self._states[i].is_done: - self._states[i].is_done = finish_func(self._states[i].content, extra_args) - for i in range(self.size): - if not self._states[i].is_done: - self.is_done = False - break - else: - self.is_done = True - - @cython.cdivision(True) - cdef int _fill(self, Queue* q, weight_t** scores, int** is_valid) except -1: - """Populate the queue from a k * n matrix of scores, where k is the - beam-width, and n is the number of classes. - """ - cdef Entry entry - cdef weight_t score - cdef _State* s - cdef int i, j, move_id - assert self.size >= 1 - cdef vector[Entry] entries - for i in range(self.size): - s = &self._states[i] - move_id = i * self.nr_class - if s.is_done: - # Update score by path average, following TACL '13 paper. - if self.histories[i]: - entry.first = s.score + (s.score / self.t) - else: - entry.first = s.score - entry.second = move_id - entries.push_back(entry) - else: - for j in range(self.nr_class): - if is_valid[i][j]: - entry.first = s.score + scores[i][j] - entry.second = move_id + j - entries.push_back(entry) - cdef double max_, Z, cutoff - if self.min_density == 0.0: - for i in range(entries.size()): - q.push(entries[i]) - elif not entries.empty(): - max_ = entries[0].first - Z = 0. - cutoff = 0. - # Softmax into probabilities, so we can prune - for i in range(entries.size()): - if entries[i].first > max_: - max_ = entries[i].first - for i in range(entries.size()): - Z += exp(entries[i].first-max_) - cutoff = (1. 
/ Z) * self.min_density - for i in range(entries.size()): - prob = exp(entries[i].first-max_) / Z - if prob >= cutoff: - q.push(entries[i]) - - -cdef class MaxViolation: - def __init__(self): - self.p_score = 0.0 - self.g_score = 0.0 - self.Z = 0.0 - self.gZ = 0.0 - self.delta = -1 - self.cost = 0 - self.p_hist = [] - self.g_hist = [] - self.p_probs = [] - self.g_probs = [] - - cpdef int check(self, Beam pred, Beam gold) except -1: - cdef _State* p = &pred._states[0] - cdef _State* g = &gold._states[0] - cdef weight_t d = p.score - g.score - if p.loss >= 1 and (self.cost == 0 or d > self.delta): - self.cost = p.loss - self.delta = d - self.p_hist = list(pred.histories[0]) - self.g_hist = list(gold.histories[0]) - self.p_score = p.score - self.g_score = g.score - self.Z = 1e-10 - self.gZ = 1e-10 - for i in range(pred.size): - if pred._states[i].loss > 0: - self.Z += exp(pred._states[i].score) - for i in range(gold.size): - if gold._states[i].loss == 0: - prob = exp(gold._states[i].score) - self.Z += prob - self.gZ += prob - - cpdef int check_crf(self, Beam pred, Beam gold) except -1: - d = pred.score - gold.score - seen_golds = set([tuple(gold.histories[i]) for i in range(gold.size)]) - if pred.loss > 0 and (self.cost == 0 or d > self.delta): - p_hist = [] - p_scores = [] - g_hist = [] - g_scores = [] - for i in range(pred.size): - if pred._states[i].loss > 0: - p_scores.append(pred._states[i].score) - p_hist.append(list(pred.histories[i])) - # This can happen from non-monotonic actions - # If we find a better gold analysis this way, be sure to keep it. - elif pred._states[i].loss <= 0 \ - and tuple(pred.histories[i]) not in seen_golds: - g_scores.append(pred._states[i].score) - g_hist.append(list(pred.histories[i])) - for i in range(gold.size): - if gold._states[i].loss == 0: - g_scores.append(gold._states[i].score) - g_hist.append(list(gold.histories[i])) - - all_probs = _softmax(p_scores + g_scores) - p_probs = all_probs[:len(p_scores)] - g_probs_all = all_probs[len(p_scores):] - g_probs = _softmax(g_scores) - - self.cost = pred.loss - self.delta = d - self.p_hist = p_hist - self.g_hist = g_hist - # TODO: These variables are misnamed! These are the gradients of the loss. - self.p_probs = p_probs - # Intuition here: - # The gradient of the loss is: - # P(model) - P(truth) - # Normally, P(truth) is 1 for the gold - # But, if we want to do the "partial credit" scheme, we want - # to create a distribution over the gold, proportional to the scores - # awarded. 
- self.g_probs = [x-y for x, y in zip(g_probs_all, g_probs)] - - -def _softmax(nums): - if not nums: - return [] - max_ = max(nums) - nums = [(exp(n-max_) if n is not None else None) for n in nums] - Z = sum(n for n in nums if n is not None) - return [(n/Z if n is not None else None) for n in nums] diff --git a/thinc/extra/tests/__init__.py b/thinc/extra/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/thinc/extra/tests/c_test_search.pyx b/thinc/extra/tests/c_test_search.pyx deleted file mode 100644 index 81327f5a9..000000000 --- a/thinc/extra/tests/c_test_search.pyx +++ /dev/null @@ -1,83 +0,0 @@ -# cython: profile=False -from cymem.cymem cimport Pool - -from thinc.extra.search cimport Beam -from thinc.typedefs cimport class_t, weight_t - - -cdef struct TestState: - int length - int x - Py_UNICODE* string - - -cdef int transition(void* dest, void* src, class_t clas, void* extra_args) except -1: - dest_state = dest - src_state = src - dest_state.length = src_state.length - dest_state.x = src_state.x - dest_state.x += clas - if extra_args != NULL: - dest_state.string = extra_args - else: - dest_state.string = src_state.string - - -cdef void* initialize(Pool mem, int n, void* extra_args) except NULL: - state = mem.alloc(1, sizeof(TestState)) - state.length = n - state.x = 1 - if extra_args == NULL: - state.string = 'default' - else: - state.string = extra_args - return state - - -cdef int destroy(Pool mem, void* state, void* extra_args) except -1: - state = state - mem.free(state) - - -def test_init(nr_class, beam_width): - b = Beam(nr_class, beam_width) - assert b.size == 1 - assert b.width == beam_width - assert b.nr_class == nr_class - - -def test_initialize(nr_class, beam_width, length): - b = Beam(nr_class, beam_width) - b.initialize(initialize, destroy, length, NULL) - for i in range(b.width): - s = b.at(i) - assert s.length == length, s.length - assert s.string == 'default' - - -def test_initialize_extra(nr_class, beam_width, length, unicode extra): - b = Beam(nr_class, beam_width) - b.initialize(initialize, destroy, length, extra) - for i in range(b.width): - s = b.at(i) - assert s.length == length - - -def test_transition(nr_class=3, beam_width=6, length=3): - b = Beam(nr_class, beam_width) - b.initialize(initialize, destroy, length, NULL) - b.set_cell(0, 2, 30, True, 0) - b.set_cell(0, 1, 42, False, 0) - b.advance(transition, NULL, NULL) - assert b.size == 1, b.size - assert b.score == 30, b.score - s = b.at(0) - assert s.x == 3 - assert b._states[0].score == 30, b._states[0].score - b.set_cell(0, 1, 10, True, 0) - b.set_cell(0, 2, 20, True, 0) - b.advance(transition, NULL, NULL) - assert b._states[0].score == 50, b._states[0].score - assert b._states[1].score == 40 - s = b.at(0) - assert s.x == 5 diff --git a/thinc/layers/strings2arrays.py b/thinc/layers/strings2arrays.py index ed40b1e88..eba2c983d 100644 --- a/thinc/layers/strings2arrays.py +++ b/thinc/layers/strings2arrays.py @@ -1,3 +1,4 @@ +from ctypes import c_uint64 from typing import Callable, List, Sequence, Tuple from murmurhash import hash_unicode @@ -17,9 +18,9 @@ def strings2arrays() -> Model[InT, OutT]: def forward(model: Model[InT, OutT], Xs: InT, is_train: bool) -> Tuple[OutT, Callable]: - hashes = model.ops.asarray2i( - [[hash_unicode(word) for word in X] for X in Xs], dtype="int32" - ) + # Cast 32-bit (signed) integer to 64-bit unsigned, since such casting + # is deprecated in NumPy. 
+ hashes = [[c_uint64(hash_unicode(word)).value for word in X] for X in Xs] hash_arrays = [model.ops.asarray1i(h, dtype="uint64") for h in hashes] arrays = [model.ops.reshape2i(array, -1, 1) for array in hash_arrays] diff --git a/thinc/optimizers.py b/thinc/optimizers.py index 4b4eca2b6..071ad4e85 100644 --- a/thinc/optimizers.py +++ b/thinc/optimizers.py @@ -1,14 +1,16 @@ +import itertools import math from collections import defaultdict -from typing import Dict, List, Optional, Tuple, Union, cast +from types import GeneratorType +from typing import Any, Dict, List, Optional, Tuple, Union, cast from .backends import get_array_ops from .config import registry +from .schedules import Schedule, constant from .types import FloatsXd, Generator KeyT = Tuple[int, str] -FloatOrSeq = Union[float, List[float], Generator] -IntOrSeq = Union[int, List[int], Generator] +ScheduleT = Union[float, List[float], Generator, Schedule] SGD_DEFAULTS: Dict[str, Union[float, bool, int]] = { "L2": 0.0, @@ -30,14 +32,14 @@ @registry.optimizers("RAdam.v1") def RAdam( - learn_rate: FloatOrSeq = ADAM_DEFAULTS["learn_rate"], + learn_rate: ScheduleT = ADAM_DEFAULTS["learn_rate"], *, - beta1: FloatOrSeq = ADAM_DEFAULTS["beta1"], - beta2: FloatOrSeq = ADAM_DEFAULTS["beta2"], - eps: FloatOrSeq = ADAM_DEFAULTS["eps"], - L2: FloatOrSeq = ADAM_DEFAULTS["L2"], + beta1: ScheduleT = ADAM_DEFAULTS["beta1"], + beta2: ScheduleT = ADAM_DEFAULTS["beta2"], + eps: ScheduleT = ADAM_DEFAULTS["eps"], + L2: ScheduleT = ADAM_DEFAULTS["L2"], L2_is_weight_decay: bool = cast(bool, ADAM_DEFAULTS["L2_is_weight_decay"]), - grad_clip: FloatOrSeq = ADAM_DEFAULTS["grad_clip"], + grad_clip: ScheduleT = ADAM_DEFAULTS["grad_clip"], use_averages: bool = True, ): return Optimizer( @@ -55,13 +57,13 @@ def RAdam( @registry.optimizers("Adam.v1") def Adam( - learn_rate: FloatOrSeq = ADAM_DEFAULTS["learn_rate"], + learn_rate: ScheduleT = ADAM_DEFAULTS["learn_rate"], *, - L2: FloatOrSeq = ADAM_DEFAULTS["L2"], - beta1: FloatOrSeq = ADAM_DEFAULTS["beta1"], - beta2: FloatOrSeq = ADAM_DEFAULTS["beta2"], - eps: FloatOrSeq = ADAM_DEFAULTS["eps"], - grad_clip: FloatOrSeq = ADAM_DEFAULTS["grad_clip"], + L2: ScheduleT = ADAM_DEFAULTS["L2"], + beta1: ScheduleT = ADAM_DEFAULTS["beta1"], + beta2: ScheduleT = ADAM_DEFAULTS["beta2"], + eps: ScheduleT = ADAM_DEFAULTS["eps"], + grad_clip: ScheduleT = ADAM_DEFAULTS["grad_clip"], L2_is_weight_decay: bool = cast(bool, ADAM_DEFAULTS["L2_is_weight_decay"]), use_averages: bool = True, ): @@ -80,10 +82,10 @@ def Adam( @registry.optimizers("SGD.v1") def SGD( - learn_rate: FloatOrSeq, + learn_rate: ScheduleT, *, - L2: FloatOrSeq = SGD_DEFAULTS["L2"], - grad_clip: FloatOrSeq = SGD_DEFAULTS["grad_clip"], + L2: ScheduleT = SGD_DEFAULTS["L2"], + grad_clip: ScheduleT = SGD_DEFAULTS["grad_clip"], L2_is_weight_decay: bool = cast(bool, SGD_DEFAULTS["L2_is_weight_decay"]), use_averages: bool = True, ): @@ -109,15 +111,17 @@ class Optimizer(object): schedules: Dict[str, Generator] nr_update: Dict[KeyT, int] last_seen: Dict[KeyT, int] - grad_clip: float - learn_rate: float - b1: float - b2: float - eps: float - L2: float + grad_clip: Schedule + learn_rate: Schedule + b1: Schedule + b2: Schedule + eps: Schedule + L2: Schedule use_radam: bool L2_is_weight_decay: bool _radam_buffer: List[List[Optional[FloatsXd]]] + _step: int + _last_score: Optional[Tuple[int, float]] # This "locks" the class, so we get an error if you try to assign to # an unexpected variable. 
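
The optimizer changes in this file normalise every hyperparameter to a `Schedule` (plain floats become constant schedules, lists and generators are wrapped) and track an internal step counter plus the last recorded score. A hedged usage sketch of how the reworked optimizer might be driven, assuming `warmup_linear` keeps its `(initial_rate, warmup_steps, total_steps)` arguments:

```python
from thinc.api import Adam, constant, warmup_linear

optimizer = Adam(
    learn_rate=warmup_linear(0.001, 100, 1000),  # a Schedule object
    L2=constant(1e-6),                           # plain floats are also accepted and wrapped
)

for _ in range(10):
    # ... run one training step, then advance the schedules ...
    optimizer.step_schedules()  # now simply increments the internal step counter

optimizer.last_score = 0.85     # stored as (step, score) and forwarded to
                                # score-aware schedules such as the new plateau
assert optimizer.step == 10
```
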
@@ -137,17 +141,19 @@ class Optimizer(object): "use_radam", "L2_is_weight_decay", "_radam_buffer", + "_step", + "_last_score", ] def __init__( self, - learn_rate: FloatOrSeq, + learn_rate: ScheduleT, *, - L2: FloatOrSeq = ADAM_DEFAULTS["L2"], - beta1: FloatOrSeq = ADAM_DEFAULTS["beta1"], - beta2: FloatOrSeq = ADAM_DEFAULTS["beta2"], - eps: FloatOrSeq = ADAM_DEFAULTS["eps"], - grad_clip: FloatOrSeq = ADAM_DEFAULTS["grad_clip"], + L2: ScheduleT = ADAM_DEFAULTS["L2"], + beta1: ScheduleT = ADAM_DEFAULTS["beta1"], + beta2: ScheduleT = ADAM_DEFAULTS["beta2"], + eps: ScheduleT = ADAM_DEFAULTS["eps"], + grad_clip: ScheduleT = ADAM_DEFAULTS["grad_clip"], use_averages: bool = True, use_radam: bool = False, L2_is_weight_decay: bool = True, @@ -166,13 +172,14 @@ def __init__( L2_is_weight_decay (bool): Whether to interpret the L2 parameter as a weight decay term, in the style of the AdamW optimizer. """ + self._step = 0 + self._last_score = None self.mom1 = {} self.mom2 = {} if use_averages: self.averages = {} else: self.averages = None - self.schedules = {} self.nr_update = defaultdict(int) self.last_seen = defaultdict(int) self._set_attr_or_schedule("grad_clip", grad_clip) @@ -187,24 +194,38 @@ def __init__( def _set_attr_or_schedule(self, name, value): if isinstance(value, (float, bool, int)): + setattr(self, name, constant(value)) + elif isinstance(value, list): + value = iter(value) + setattr(self, name, _wrap_generator(name, value)) + elif isinstance(value, GeneratorType): + setattr(self, name, _wrap_generator(name, value)) + elif isinstance(value, Schedule): setattr(self, name, value) else: - if isinstance(value, list): - value = iter(value) - self.schedules[name] = value - try: - setattr(self, name, next(value)) - except (StopIteration, TypeError) as e: - err = f"Invalid schedule for '{name}' ({type(value)})\n{e}" - raise ValueError(err) + err = f"Invalid schedule for '{name}' ({type(value)})" + raise ValueError(err) def step_schedules(self): - for key, schedule in self.schedules.items(): - try: - value = next(schedule) - except StopIteration: # schedule exhausted, use last value - value = getattr(self, key) - setattr(self, key, value) + self._step += 1 + + @property + def last_score(self) -> Optional[Tuple[int, float]]: + return self._last_score + + @last_score.setter + def last_score(self, score: float): + self._last_score = (self._step, score) + + @property + def step(self) -> int: + return self._step + + def _schedule_args(self, key: KeyT) -> Dict[str, Any]: + return { + "key": key, + "last_score": self.last_score, + } def __call__( self, @@ -219,28 +240,42 @@ def __call__( """ if len(gradient) < 1: return weights, gradient + ops = get_array_ops(weights) self.nr_update[key] += 1 nr_upd = self.nr_update[key] - if self.L2 != 0 and not self.L2_is_weight_decay: - gradient += self.L2 * weights - if self.grad_clip: - gradient = ops.clip_gradient(gradient, self.grad_clip) + schedule_args = self._schedule_args(key) + + if self.L2(self.step, **schedule_args) != 0 and not self.L2_is_weight_decay: + gradient += self.L2(self.step, **schedule_args) * weights + if self.grad_clip(self.step, **schedule_args): + gradient = ops.clip_gradient( + gradient, + self.grad_clip(self.step, **schedule_args), + ) if self.use_radam: weights, gradient = self._radam( ops, weights, gradient, lr_scale, key, nr_upd ) - elif self.b1 > 0.0 and self.b2 > 0.0: + elif ( + self.b1(self.step, **schedule_args) > 0.0 + and self.b2(self.step, **schedule_args) > 0.0 + ): weights, gradient = self._adam( ops, weights, gradient, 
lr_scale, key, nr_upd ) - elif self.b2 > 0.0: # pragma: no cover + elif self.b2(self.step, **schedule_args) > 0.0: # pragma: no cover raise NotImplementedError # TODO: error message else: - weights -= lr_scale * self.learn_rate * gradient + weights -= lr_scale * self.learn_rate(self.step, **schedule_args) * gradient gradient *= 0 - if self.L2 != 0 and self.L2_is_weight_decay: - weights -= lr_scale * self.learn_rate * self.L2 * weights + if self.L2(self.step, **schedule_args) != 0 and self.L2_is_weight_decay: + weights -= ( + lr_scale + * self.learn_rate(self.step, **schedule_args) + * self.L2(self.step, **schedule_args) + * weights + ) if self.averages is not None: if key not in self.averages: self.averages[key] = ops.alloc(weights.shape, dtype="float32") @@ -256,6 +291,8 @@ def _radam(self, ops, weights, grad, lr_scale, key, nr_upd): weights_1D = ops.reshape1f(weights, weights.size) gradient_1D = ops.reshape1f(grad, grad.size) + schedule_args = self._schedule_args(key) + # While we port from the pytorch implementation, keep some of the same # naming state = { @@ -264,9 +301,12 @@ def _radam(self, ops, weights, grad, lr_scale, key, nr_upd): "exp_avg_sq": self.mom2[key], } group = { - "lr": self.learn_rate, - "betas": [self.b1, self.b2], - "eps": self.eps, + "lr": self.learn_rate(self.step, **schedule_args), + "betas": [ + self.b1(self.step, **schedule_args), + self.b2(self.step, **schedule_args), + ], + "eps": self.eps(self.step, **schedule_args), "weight_decay": 0.0, "buffer": self._radam_buffer, } @@ -328,18 +368,21 @@ def _radam(self, ops, weights, grad, lr_scale, key, nr_upd): def _adam(self, ops, weights, gradient, lr_scale, key, nr_upd): weights_1D = ops.reshape1f(weights, weights.size) gradient_1D = ops.reshape1f(gradient, gradient.size) + + schedule_args = self._schedule_args(key) + if key not in self.mom1: self.mom1[key] = ops.alloc1f(weights.size) if key not in self.mom2: self.mom2[key] = ops.alloc1f(weights.size) mom1 = self.mom1[key] mom2 = self.mom2[key] - b1 = self.b1 - b2 = self.b2 + b1 = self.b1(self.step, **schedule_args) + b2 = self.b2(self.step, **schedule_args) fix1 = 1.0 - (b1**nr_upd) fix2 = 1.0 - (b2**nr_upd) - lr = self.learn_rate * fix2**0.5 / fix1 - eps = self.eps + lr = self.learn_rate(self.step, **schedule_args) * fix2**0.5 / fix1 + eps = self.eps(self.step, **schedule_args) # needs to be 1D going into the adam function weights_1D, gradient_1D, mom1, mom2 = ops.adam( weights_1D, gradient_1D, mom1, mom2, b1, b2, eps, lr * lr_scale @@ -352,4 +395,49 @@ def _adam(self, ops, weights, gradient, lr_scale, key, nr_upd): ) +def _wrap_generator(attr_name: str, generator: Generator) -> Schedule[Any]: + try: + peek = next(generator) + except (StopIteration, TypeError) as e: + err = f"Invalid schedule for '{attr_name}' ({type(generator)})\n{e}" + raise ValueError(err) + return Schedule( + "wrap_generator", + _wrap_generator_schedule, + attrs={ + "attr_name": attr_name, + "last_step": -1, + "last_value": peek, + "generator": itertools.chain([peek], generator), + }, + ) + + +def _wrap_generator_schedule(schedule: Schedule, step, **kwargs) -> float: + attr_name = schedule.attrs["attr_name"] + last_step = schedule.attrs["last_step"] + last_value = schedule.attrs["last_value"] + generator = schedule.attrs["generator"] + + if step < last_step: + raise ValueError( + f"'step' of the generator-based schedule for {attr_name} must not decrease" + ) + + # Ensure that we have a value when we didn't step or when the + # generator is exhausted. 
+ value = last_value + + for i in range(step - last_step): + try: + value = next(generator) + except StopIteration: # schedule exhausted, use last value + break + + schedule.attrs["last_step"] = step + schedule.attrs["last_value"] = value + + return value + + __all__ = ["Adam", "RAdam", "SGD", "Optimizer", "ADAM_DEFAULTS", "SGD_DEFAULTS"] diff --git a/thinc/schedules.py b/thinc/schedules.py index c13868a5d..2f99a536a 100644 --- a/thinc/schedules.py +++ b/thinc/schedules.py @@ -1,33 +1,107 @@ """Generators that provide different rates, schedules, decays or series.""" -from typing import Iterable +import itertools +from dataclasses import dataclass +from typing import Any, Callable, Dict, Generator, Generic, Optional, Tuple, TypeVar import numpy from .config import registry +OutT = TypeVar("OutT") + + +class Schedule(Generic[OutT]): + """Class for implementing Thinc schedules.""" + + name: str + _schedule: Callable + _attrs: Dict[str, Any] + + __slots__ = ["name", "_schedule", "_attrs"] + + def __init__( + self, name: str, schedule: Callable, *, attrs: Dict[str, Any] = {} + ) -> None: + """Initialize a new schedule. + + name (str): The name of the schedule type. + schedule (Callable): The schedule function. + """ + self.name = name + self._schedule = schedule + self._attrs = dict(attrs) + + def __call__(self, step: int, **extra) -> OutT: + """Compute the schedule for a given step.""" + + if step < 0: + raise ValueError(f"Step must be non-negative, was: {step}") + + return self._schedule(self, step, **extra) + + @property + def attrs(self): + """Schedule attributes.""" + return self._attrs + + def to_generator( + self, start: int = 0, step_size=1, **extra + ) -> Generator[OutT, None, None]: + """Turn the schedule into a generator. + + start (int): The schedule initial step. + step_size (int): The amount to increase the step for each generated value. + **extra: Additional arguments that are passed to the schedule. + RETURNS (Generator[OutT, None, None]): The generator. 
+ """ + if start < 0: + raise ValueError(f"Schedule start must be non-negative, was: {start}") + if step_size < 0: + raise ValueError(f"Step size must be non-negative, was: {step_size}") + + def generate(): + for step in itertools.count(start, step_size): + yield self(step, **extra) + + return generate() + @registry.schedules("constant_then.v1") -def constant_then( - rate: float, steps: int, schedule: Iterable[float] -) -> Iterable[float]: +def constant_then(rate: OutT, steps: int, schedule: Schedule[OutT]) -> Schedule[OutT]: """Yield a constant rate for N steps, before starting a schedule.""" - for i in range(steps): - yield rate - for value in schedule: - yield value + return Schedule( + "constant_then", + _constant_then_schedule, + attrs={"rate": rate, "steps": steps, "schedule": schedule}, + ) + + +def _constant_then_schedule(schedule: Schedule, step: int, **kwargs) -> float: + rate = schedule.attrs["rate"] + steps = schedule.attrs["steps"] + schedule = schedule.attrs["schedule"] + + if step < steps: + return rate + else: + return schedule(step=step, **kwargs) @registry.schedules("constant.v1") -def constant(rate: float) -> Iterable[float]: +def constant(rate: OutT) -> Schedule[OutT]: """Yield a constant rate.""" - while True: - yield rate + return Schedule("constant", _constant_schedule, attrs={"rate": rate}) + + +def _constant_schedule(schedule: Schedule, step: int, **kwargs) -> float: + rate = schedule.attrs["rate"] + return rate @registry.schedules("decaying.v1") -def decaying(base_rate: float, decay: float, *, t: int = 0) -> Iterable[float]: +def decaying(base_rate: float, decay: float, *, t: float = 0.0) -> Schedule[float]: """Yield an infinite series of linearly decaying values, - following the schedule: base_rate * 1 / (1 + decay * t) + following the schedule: base_rate * 1 / (1 + decay * (t + step)) EXAMPLE: >>> learn_rates = decaying(0.001, 1e-4) @@ -36,15 +110,24 @@ def decaying(base_rate: float, decay: float, *, t: int = 0) -> Iterable[float]: >>> next(learn_rates) 0.00999 """ - while True: - yield base_rate * (1.0 / (1.0 + decay * t)) - t += 1 + return Schedule( + "decaying", + _decaying_schedule, + attrs={"base_rate": base_rate, "decay": decay, "t": t}, + ) + + +def _decaying_schedule(schedule: Schedule, step: int, **kwargs) -> float: + base_rate = schedule.attrs["base_rate"] + decay = schedule.attrs["decay"] + t = schedule.attrs["t"] + return base_rate * (1.0 / (1.0 + decay * (step + t))) @registry.schedules("compounding.v1") def compounding( start: float, stop: float, compound: float, *, t: float = 0.0 -) -> Iterable[float]: +) -> Schedule[float]: """Yield an infinite series of compounding values. Each time the generator is called, a value is produced by multiplying the previous value by the compound rate. 
@@ -55,16 +138,128 @@ def compounding( >>> assert next(sizes) == 1 * 1.5 >>> assert next(sizes) == 1.5 * 1.5 """ - curr = float(start) - while True: - yield _clip(curr, start, stop) - curr *= compound + return Schedule( + "compounding", + _compounding_schedule, + attrs={"start": start, "stop": stop, "compound": compound, "t": t}, + ) + + +def _compounding_schedule(schedule: Schedule, step: int, **kwargs) -> float: + start = schedule.attrs["start"] + stop = schedule.attrs["stop"] + compound = schedule.attrs["compound"] + t = schedule.attrs["t"] + return _clip(start * (compound ** (step + t)), start, stop) def _clip(value: float, start: float, stop: float) -> float: return max(value, stop) if (start > stop) else min(value, stop) +@registry.schedules("plateau.v1") +def plateau( + max_patience: int, scale: float, schedule: Schedule[float] +) -> Schedule[float]: + + """Yields values from the wrapped schedule, exponentially scaled by the + number of times optimization has plateaued. The caller must pass model + evaluation scores through the last_score argument for the scaling to be + adjusted. The last evaluation score is passed through the last_score argument + as a tuple (last_score_step, last_score). This tuple indicates when a model + was last evaluated (last_score_step) and with what score (last_score). + + max_patience (int): the number of evaluations without improvement when + we consider the model to have plateaued. + scale (float): scaling of the inner schedule (scale**n_plateaus * inner). + schedule (Schedule[float]): the schedule to wrap. + """ + + return Schedule( + "plateau", + _plateau_schedule, + attrs={ + "scale": scale, + "max_patience": max_patience, + "schedule": schedule, + "state": _PlateauState( + best_score=None, last_score_step=None, patience=0, n_plateaus=0 + ), + }, + ) + + +def _plateau_schedule( + schedule: Schedule, + step: int, + *, + last_score: Optional[Tuple[int, float]] = None, + **kwargs, +) -> float: + inner_schedule: Schedule[float] = schedule.attrs["schedule"] + max_patience: int = schedule.attrs["max_patience"] + scale: float = schedule.attrs["scale"] + state: _PlateauState = schedule.attrs["state"] + + if last_score is None: + return (scale**state.n_plateaus) * inner_schedule( + step=step, last_score=last_score, **kwargs + ) + + last_score_step, last_score_ = last_score + + if ( + state.best_score is None + or state.last_score_step is None + or last_score_ > state.best_score + ): + state.best_score = last_score_ + state.patience = 0 + elif last_score_step < state.last_score_step: + raise ValueError( + f"Expected score with step >= {state.last_score_step}, was: {last_score_step}" + ) + elif last_score_step > state.last_score_step: + # If the score didn't improve and we are not seeing the last + # score again, we may be at a plateau, so increase patience. + state.patience += 1 + + # If we are at the maximum patience, we consider the optimization + # to have reached a plateau. + if state.patience == max_patience: + state.n_plateaus += 1 + state.patience = 0 + + state.last_score_step = last_score_step + + return (scale**state.n_plateaus) * inner_schedule( + step=step, last_score=last_score, **kwargs + ) + + +@dataclass +class _PlateauState: + """Plateau schedule state. + + best_score (Optional[float]): the best score so far, or None when no + score has been observed. + last_score_step (Optional[int]): the step of the last score that was + observed. 
+ patience (int): the number of scores so far which do not improve over + the best score (reset after reaching the maximum patience). + n_plateaus (int): the number of times the maximum patience has been + reached. + """ + + best_score: Optional[float] + last_score_step: Optional[int] + patience: int + n_plateaus: int + + # @dataclass(slots=True) is only supported in Python >= 3.10 + __slots__ = ["best_score", "last_score_step", "patience", "n_plateaus"] + + @registry.schedules("slanted_triangular.v1") def slanted_triangular( max_rate: float, @@ -72,52 +267,90 @@ def slanted_triangular( *, cut_frac: float = 0.1, ratio: int = 32, - decay: float = 1.0, t: float = 0.0, -) -> Iterable[float]: +) -> Schedule[float]: """Yield an infinite series of values according to Howard and Ruder's "slanted triangular learning rate" schedule. """ cut = int(num_steps * cut_frac) - while True: - t += 1 - if t < cut: - p = t / cut - else: - p = 1 - ((t - cut) / (cut * (1 / cut_frac - 1))) - learn_rate = max_rate * (1 + p * (ratio - 1)) * (1 / ratio) - yield learn_rate + return Schedule( + "slanted_triangular", + _slanted_triangular_schedule, + attrs={ + "max_rate": max_rate, + "cut": cut, + "cut_frac": cut_frac, + "ratio": ratio, + "t": t, + }, + ) + + +def _slanted_triangular_schedule(schedule: Schedule, step: int, **kwargs) -> float: + max_rate = schedule.attrs["max_rate"] + cut = schedule.attrs["cut"] + cut_frac = schedule.attrs["cut_frac"] + ratio = schedule.attrs["ratio"] + t = schedule.attrs["t"] + + t_step = step + t + 1.0 + if t_step < cut: + p = t_step / cut + else: + p = 1 - ((t_step - cut) / (cut * (1 / cut_frac - 1))) + return max_rate * (1 + p * (ratio - 1)) * (1 / ratio) @registry.schedules("warmup_linear.v1") def warmup_linear( initial_rate: float, warmup_steps: int, total_steps: int -) -> Iterable[float]: +) -> Schedule[float]: """Generate a series, starting from an initial rate, and then with a warmup period, and then a linear decline. Used for learning rates. 
""" - step = 0 - while True: - if step < warmup_steps: - factor = step / max(1, warmup_steps) - else: - factor = max( - 0.0, (total_steps - step) / max(1.0, total_steps - warmup_steps) - ) - yield factor * initial_rate - step += 1 + return Schedule( + "warmup_linear", + _warmup_linear_schedule, + attrs={ + "initial_rate": initial_rate, + "warmup_steps": warmup_steps, + "total_steps": total_steps, + }, + ) + + +def _warmup_linear_schedule(schedule: Schedule, step: int, **kwargs) -> float: + initial_rate = schedule.attrs["initial_rate"] + warmup_steps = schedule.attrs["warmup_steps"] + total_steps = schedule.attrs["total_steps"] + + if step < warmup_steps: + factor = step / max(1, warmup_steps) + else: + factor = max(0.0, (total_steps - step) / max(1.0, total_steps - warmup_steps)) + return factor * initial_rate @registry.schedules("cyclic_triangular.v1") -def cyclic_triangular(min_lr: float, max_lr: float, period: int) -> Iterable[float]: - it = 1 - while True: - # https://towardsdatascience.com/adaptive-and-cyclical-learning-rates-using-pytorch-2bf904d18dee - cycle = numpy.floor(1 + it / (2 * period)) - x = numpy.abs(it / period - 2 * cycle + 1) - relative = max(0, 1 - x) - yield min_lr + (max_lr - min_lr) * relative - it += 1 +def cyclic_triangular(min_lr: float, max_lr: float, period: int) -> Schedule[float]: + return Schedule( + "cyclic_triangular", + _cyclic_triangular_schedule, + attrs={"min_lr": min_lr, "max_lr": max_lr, "period": period}, + ) + + +def _cyclic_triangular_schedule(schedule: Schedule, step: int, **kwargs) -> float: + min_lr = schedule.attrs["min_lr"] + max_lr = schedule.attrs["max_lr"] + period = schedule.attrs["period"] + + it = step + 1 + # https://towardsdatascience.com/adaptive-and-cyclical-learning-rates-using-pytorch-2bf904d18dee + cycle = numpy.floor(1 + it / (2 * period)) + x = numpy.abs(it / period - 2 * cycle + 1) + relative = max(0, 1 - x) + return min_lr + (max_lr - min_lr) * relative __all__ = [ diff --git a/thinc/shims/torchscript.py b/thinc/shims/torchscript.py index 6c05c8a9b..9d413f93a 100644 --- a/thinc/shims/torchscript.py +++ b/thinc/shims/torchscript.py @@ -30,7 +30,7 @@ class TorchScriptShim(PyTorchShim): def __init__( self, - model: Optional["torch.ScriptModule"], + model: Optional["torch.jit.ScriptModule"], config=None, optimizer: Any = None, mixed_precision: bool = False, diff --git a/thinc/extra/__init__.py b/thinc/tests/backends/_apple_blas/__init__.py similarity index 100% rename from thinc/extra/__init__.py rename to thinc/tests/backends/_apple_blas/__init__.py diff --git a/thinc/tests/backends/_apple_blas/test_gemm.py b/thinc/tests/backends/_apple_blas/test_gemm.py new file mode 100644 index 000000000..10e662110 --- /dev/null +++ b/thinc/tests/backends/_apple_blas/test_gemm.py @@ -0,0 +1,79 @@ +import numpy +import pytest + +from thinc.compat import has_apple_ops + +try: + import thinc.backends._accelerate as accelerate +except: + pass + + +@pytest.mark.skipif(not has_apple_ops, reason="Apple ops not available") +def test_basic_sgemm(): + A = numpy.random.randn(5, 4).astype("f") + B = numpy.random.randn(4, 7).astype("f") + C = accelerate.gemm(A, B) + assert C.shape == (A.shape[0], B.shape[1]) + + C_out = numpy.empty((5, 7), dtype="f") + accelerate.gemm(A, B, out=C_out) + + numpy.testing.assert_allclose(C, C_out) + + +@pytest.mark.skipif(not has_apple_ops, reason="Apple ops not available") +def test_incorrect_output_size(): + A = numpy.ndarray((5, 4), dtype="f") + B = numpy.ndarray((4, 7), dtype="f") + + with pytest.raises(ValueError, 
match=r"Shape mismatch for output matrix"): + accelerate.gemm(A, B, out=numpy.ndarray((3, 7), dtype="f")) + + with pytest.raises(ValueError, match=r"Shape mismatch for output matrix"): + accelerate.gemm(A, B, out=numpy.ndarray((5, 3), dtype="f")) + + +@pytest.mark.skipif(not has_apple_ops, reason="Apple ops not available") +@pytest.mark.parametrize( + "A_shape,B_shape,transA,transB", + [ + [(0, 0), (0, 0), False, False], + [(0, 0), (0, 0), True, False], + [(0, 0), (0, 0), False, True], + [(0, 0), (0, 0), True, True], + [(0, 5), (5, 0), False, False], + [(5, 0), (5, 0), False, True], + [(5, 0), (5, 0), True, False], + ], +) +def test_zero_size(A_shape, B_shape, transA, transB): + A = numpy.ndarray(A_shape, dtype="f") + B = numpy.ndarray(B_shape, dtype="f") + if not transA and not transB: + C = numpy.dot(A, B) + elif transA: + C = numpy.dot(A.T, B) + elif transB: + C = numpy.dot(A, B.T) + else: + C = numpy.dot(A.T, B.T) + C_ = accelerate.gemm(A, B, trans1=transA, trans2=transB) + assert C.shape == C_.shape + + +@pytest.mark.skipif(not has_apple_ops, reason="Apple ops not available") +@pytest.mark.parametrize( + "A_shape,B_shape,transA,transB", + [ + [(4, 5), (4, 5), False, False], + [(5, 4), (4, 5), True, False], + [(4, 5), (5, 4), False, True], + [(5, 4), (5, 4), True, True], + ], +) +def test_incorrect_shapes(A_shape, B_shape, transA, transB): + A = numpy.ndarray(A_shape, dtype="f") + B = numpy.ndarray(B_shape, dtype="f") + with pytest.raises(ValueError, match=r"Shape mismatch"): + accelerate.gemm(A, B, trans1=transA, trans2=transB) diff --git a/thinc/tests/backends/test_mps_ops.py b/thinc/tests/backends/test_mps_ops.py new file mode 100644 index 000000000..1bd5838b1 --- /dev/null +++ b/thinc/tests/backends/test_mps_ops.py @@ -0,0 +1,11 @@ +from thinc.api import NumpyOps, get_ops +from thinc.compat import has_apple_ops + + +def test_mps_ops_inherits_apple_ops(): + ops = get_ops("mps") + assert isinstance(ops, NumpyOps) + if has_apple_ops: + # We can't import AppleOps directly, because its' not + # available on non-Darwin systems. + assert "AppleOps" in [base.__name__ for base in type(ops).__bases__] diff --git a/thinc/tests/backends/test_ops.py b/thinc/tests/backends/test_ops.py index 9f03c0438..7cf4a935d 100644 --- a/thinc/tests/backends/test_ops.py +++ b/thinc/tests/backends/test_ops.py @@ -1403,7 +1403,7 @@ def test_get_ops(): # If Apple ops are available, "cpu" should return AppleOps or # NumpyOps otherwise. 
try: - from thinc_apple_ops import AppleOps + from thinc.backends.apple_ops import AppleOps assert isinstance(get_ops("cpu"), AppleOps) except ImportError: diff --git a/thinc/tests/extra/__init__.py b/thinc/tests/extra/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/thinc/tests/extra/test_beam_search.py b/thinc/tests/extra/test_beam_search.py deleted file mode 100644 index ab7ab9f11..000000000 --- a/thinc/tests/extra/test_beam_search.py +++ /dev/null @@ -1,5 +0,0 @@ -from thinc.extra.search import MaxViolation - - -def test_init_violn(): - MaxViolation() diff --git a/thinc/tests/layers/test_basic_tagger.py b/thinc/tests/layers/test_basic_tagger.py index 855a6d6ad..3bc772940 100644 --- a/thinc/tests/layers/test_basic_tagger.py +++ b/thinc/tests/layers/test_basic_tagger.py @@ -60,7 +60,6 @@ def get_shuffled_batches(Xs, Ys, batch_size): yield list(batch_X), list(batch_Y) -@pytest.mark.slow @pytest.mark.parametrize( ("depth", "width", "vector_width", "nb_epoch"), [(2, 32, 16, 5)] ) diff --git a/thinc/tests/test_config.py b/thinc/tests/test_config.py index a3f4ede46..dc1c9a0af 100644 --- a/thinc/tests/test_config.py +++ b/thinc/tests/test_config.py @@ -183,9 +183,8 @@ def decaying(base_rate: float, repeat: int) -> List[float]: return repeat * [base_rate] optimizer = my_registry.resolve(config)["optimizer"] - assert optimizer.b1 == 0.2 - assert "learn_rate" in optimizer.schedules - assert optimizer.learn_rate == 0.001 + assert optimizer.b1(step=optimizer._step, key=(0, "")) == 0.2 + assert optimizer.learn_rate(step=optimizer._step, key=(0, "")) == 0.001 def test_handle_generic_model_type(): diff --git a/thinc/tests/test_optimizers.py b/thinc/tests/test_optimizers.py index 4e336640b..57b5a27ff 100644 --- a/thinc/tests/test_optimizers.py +++ b/thinc/tests/test_optimizers.py @@ -2,6 +2,9 @@ import pytest from thinc.api import Optimizer, registry +from thinc.optimizers import KeyT, _wrap_generator + +STUB_KEY: KeyT = (0, "") def _test_schedule_valid(): @@ -30,6 +33,22 @@ def schedule_valid(request): return r_func(), r1, r2, r3 +@pytest.fixture( + params=[ + (lambda: 0.123, 0.123, 0.123, 0.123), + (lambda: (i for i in [0.2, 0.1, 0.4, 0.5, 0.6, 0.7, 0.8]), 0.2, 0.1, 0.4), + (lambda: (i for i in [0.333, 0.666]), 0.333, 0.666, 0.666), + (lambda: [0.9, 0.8, 0.7], 0.9, 0.8, 0.7), + (lambda: [0.0, 0.123], 0.0, 0.123, 0.123), + ], + scope="function", +) +def schedule_config_valid(request): + # Use lambda to prevent iterator from being consumed by first test + r_func, r1, r2, r3 = request.param + return r_func(), r1, r2, r3 + + @pytest.fixture( params=[ (lambda: "hello"), @@ -50,32 +69,32 @@ def test_optimizers_from_config(name): learn_rate = 0.123 cfg = {"@optimizers": name, "learn_rate": learn_rate} optimizer = registry.resolve({"config": cfg})["config"] - assert optimizer.learn_rate == learn_rate + assert optimizer.learn_rate(step=optimizer._step, key=STUB_KEY) == learn_rate -def test_optimizer_schedules_from_config(schedule_valid): - lr, lr_next1, lr_next2, lr_next3 = schedule_valid +def test_optimizer_schedules_from_config(schedule_config_valid): + lr, lr_next1, lr_next2, lr_next3 = schedule_config_valid cfg = {"@optimizers": "Adam.v1", "learn_rate": lr} optimizer = registry.resolve({"cfg": cfg})["cfg"] - assert optimizer.learn_rate == lr_next1 + assert optimizer.learn_rate(step=optimizer._step, key=STUB_KEY) == lr_next1 optimizer.step_schedules() - assert optimizer.learn_rate == lr_next2 + assert optimizer.learn_rate(step=optimizer._step, key=STUB_KEY) == lr_next2 
optimizer.step_schedules() - assert optimizer.learn_rate == lr_next3 - optimizer.learn_rate = 1.0 - assert optimizer.learn_rate == 1.0 + assert optimizer.learn_rate(step=optimizer._step, key=STUB_KEY) == lr_next3 + optimizer.learn_rate = lambda *, step, key: 1.0 + assert optimizer.learn_rate(step=optimizer._step, key=STUB_KEY) == 1.0 def test_optimizer_schedules_valid(schedule_valid): lr, lr_next1, lr_next2, lr_next3 = schedule_valid optimizer = Optimizer(learn_rate=lr) - assert optimizer.learn_rate == lr_next1 + assert optimizer.learn_rate(step=optimizer._step, key=STUB_KEY) == lr_next1 optimizer.step_schedules() - assert optimizer.learn_rate == lr_next2 + assert optimizer.learn_rate(step=optimizer._step, key=STUB_KEY) == lr_next2 optimizer.step_schedules() - assert optimizer.learn_rate == lr_next3 - optimizer.learn_rate = 1.0 - assert optimizer.learn_rate == 1.0 + assert optimizer.learn_rate(step=optimizer._step, key=STUB_KEY) == lr_next3 + optimizer.learn_rate = lambda *, step, key: 1.0 + assert optimizer.learn_rate(step=optimizer._step, key=STUB_KEY) == 1.0 def test_optimizer_schedules_invalid(schedule_invalid): @@ -98,3 +117,31 @@ def test_optimizer_init(): optimizer((0, "x"), W, dW) optimizer = Optimizer(learn_rate=0.123, beta1=0.1, beta2=0.1) optimizer((1, "x"), W, dW) + + +def test_optimizer_last_score(): + optimizer = Optimizer( + learn_rate=0.123, + ) + + assert optimizer.last_score is None + optimizer.last_score = 1.0 + assert optimizer.last_score == (0, 1.0) + optimizer.step_schedules() + optimizer.step_schedules() + assert optimizer.last_score == (0, 1.0) + optimizer.last_score = 2.0 + assert optimizer.last_score == (2, 2.0) + + +def test_generator_schedule(): + s = _wrap_generator("test", iter([0.0, 1.0, 2.0, 3.0])) + assert s(step=0, key=STUB_KEY, last_score=None) == 0.0 + assert s(step=0, key=STUB_KEY, last_score=None) == 0.0 + assert s(step=1, key=STUB_KEY, last_score=None) == 1.0 + assert s(step=1, key=STUB_KEY, last_score=None) == 1.0 + assert s(step=3, key=STUB_KEY, last_score=None) == 3.0 + assert s(step=10, key=STUB_KEY, last_score=None) == 3.0 + + with pytest.raises(ValueError, match=r"must not decrease"): + s(step=1, key=STUB_KEY, last_score=None) diff --git a/thinc/tests/test_schedules.py b/thinc/tests/test_schedules.py index 31a8f4e3b..8eafe05e6 100644 --- a/thinc/tests/test_schedules.py +++ b/thinc/tests/test_schedules.py @@ -1,3 +1,7 @@ +from itertools import islice + +import pytest + from thinc.api import ( compounding, constant, @@ -7,64 +11,100 @@ slanted_triangular, warmup_linear, ) +from thinc.schedules import plateau def test_decaying_rate(): rates = decaying(0.001, 1e-4) - rate = next(rates) + rate = rates(step=0) assert rate == 0.001 - next_rate = next(rates) + next_rate = rates(step=1) assert next_rate < rate assert next_rate > 0 - assert next_rate > next(rates) + assert next_rate > rates(step=2) + + rates_offset = decaying(0.001, 1e-4, t=1.0) + assert rates(step=1) == rates_offset(step=0) + assert rates(step=2) == rates_offset(step=1) def test_compounding_rate(): rates = compounding(1, 16, 1.01) - rate0 = next(rates) + rate0 = rates(step=0) assert rate0 == 1.0 - rate1 = next(rates) - rate2 = next(rates) - rate3 = next(rates) + rate1 = rates(step=1) + rate2 = rates(step=2) + rate3 = rates(step=3) assert rate3 > rate2 > rate1 > rate0 assert (rate3 - rate2) > (rate2 - rate1) > (rate1 - rate0) + rates_offset = compounding(1, 16, 1.01, t=1.0) + assert rates(step=1) == rates_offset(step=0) + assert rates(step=2) == rates_offset(step=1) + def 
test_slanted_triangular_rate(): rates = slanted_triangular(1.0, 20.0, ratio=10) - rate0 = next(rates) + rate0 = rates(step=0) assert rate0 < 1.0 - rate1 = next(rates) + rate1 = rates(step=1) assert rate1 > rate0 - rate2 = next(rates) + rate2 = rates(step=2) assert rate2 < rate1 - rate3 = next(rates) + rate3 = rates(step=3) assert rate0 < rate3 < rate2 + rates_offset = slanted_triangular(1.0, 20.0, ratio=10, t=1.0) + assert rates(step=1) == rates_offset(step=0) + assert rates(step=2) == rates_offset(step=1) + def test_constant_then_schedule(): - rates = constant_then(1.0, 2, [100, 200]) - assert next(rates) == 1.0 - assert next(rates) == 1.0 - assert next(rates) == 100 - assert next(rates) == 200 + rates = constant_then(1.0, 2, constant(100)) + assert rates(step=0) == 1.0 + assert rates(step=1) == 1.0 + assert rates(step=2) == 100 + assert rates(step=3) == 100 def test_constant(): rates = constant(123) - assert next(rates) == 123 - assert next(rates) == 123 + assert rates(step=0, key=(0, "")) == 123 + assert rates(step=0, key=(0, "")) == 123 def test_warmup_linear(): rates = warmup_linear(1.0, 2, 10) expected = [0.0, 0.5, 1.0, 0.875, 0.75, 0.625, 0.5, 0.375, 0.25, 0.125, 0.0] for i in range(11): - assert next(rates) == expected[i] + assert rates(step=i, key=(0, "")) == expected[i] def test_cyclic_triangular(): rates = cyclic_triangular(0.1, 1.0, 2) expected = [0.55, 1.0, 0.55, 0.1, 0.55, 1.0, 0.55, 0.1, 0.55, 1.0] for i in range(10): - assert next(rates) == expected[i] + assert rates(step=i, key=(0, "")) == expected[i] + + +def test_plateau(): + schedule = plateau(2, 0.5, constant(1.0)) + assert schedule(step=0, last_score=None) == 1.0 + assert schedule(step=1, last_score=(1, 1.0)) == 1.0 # patience == 0 + assert schedule(step=2, last_score=(2, 1.0)) == 1.0 # patience == 1 + assert schedule(step=3, last_score=None) == 1.0 # patience == 1 + assert schedule(step=4, last_score=(4, 1.0)) == 0.5 # patience == 2, reset + assert schedule(step=5, last_score=(4, 1.0)) == 0.5 # patience == 0 + assert schedule(step=6, last_score=(6, 0.9)) == 0.5 # patience == 1 + assert schedule(step=7, last_score=(7, 2.0)) == 0.5 # patience == 0 + assert schedule(step=8, last_score=(8, 1.0)) == 0.5 # patience == 1 + assert schedule(step=9, last_score=(9, 2.0)) == 0.25 # patience == 2, reset + + with pytest.raises(ValueError, match=r"Expected score with step"): + schedule(step=1, last_score=(1, 1.0)) == 1.0 + + +def test_to_generator(): + rates = warmup_linear(1.0, 2, 10) + expected = [0.0, 0.5, 1.0, 0.875, 0.75, 0.625, 0.5, 0.375, 0.25, 0.125, 0.0] + assert list(islice(rates.to_generator(), len(expected))) == expected diff --git a/thinc/tests/test_types.py b/thinc/tests/test_types.py index 738a309f9..bf2740bbb 100644 --- a/thinc/tests/test_types.py +++ b/thinc/tests/test_types.py @@ -1,6 +1,17 @@ import numpy import pytest +from thinc.types import ( + Floats1d, + Floats2d, + Floats3d, + Floats4d, + Ints1d, + Ints2d, + Ints3d, + Ints4d, +) + try: from pydantic.v1 import ValidationError, create_model except ImportError: diff --git a/thinc/util.py b/thinc/util.py index 5ca928698..529faf875 100644 --- a/thinc/util.py +++ b/thinc/util.py @@ -9,6 +9,7 @@ from contextvars import ContextVar from dataclasses import dataclass from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -30,8 +31,10 @@ except ImportError: from pydantic import ValidationError, create_model # type: ignore +import numpy from wasabi import table +from . 
import types # noqa: E402 from .compat import ( cupy, cupy_from_dlpack, @@ -47,18 +50,15 @@ from .compat import mxnet as mx from .compat import tensorflow as tf from .compat import torch - -DATA_VALIDATION: ContextVar[bool] = ContextVar("DATA_VALIDATION", default=False) - -from typing import TYPE_CHECKING - -from . import types # noqa: E402 from .types import ArgsKwargs, ArrayXd, FloatsXd, IntsXd, Padded, Ragged # noqa: E402 if TYPE_CHECKING: from .api import Ops +DATA_VALIDATION: ContextVar[bool] = ContextVar("DATA_VALIDATION", default=False) + + def get_torch_default_device() -> "torch.device": if torch is None: raise ValueError("Cannot get default Torch device when Torch is not available.") diff --git a/website/docs/api-backends.md b/website/docs/api-backends.md index fc69a775d..853fada3b 100644 --- a/website/docs/api-backends.md +++ b/website/docs/api-backends.md @@ -17,16 +17,18 @@ specialized versions can be called for different backends. You can also create your own `Ops` subclasses with specialized routines for your layers, and use the [`set_current_ops`](#set_current_ops) function to change the default. -| Backend | CPU | GPU | TPU | Description | -| ---------- | :----------------: | :----------------: | :---------------: | ----------------------------------------------------------------------------------------------------- | -| `NumpyOps` | | | | Execute via `numpy`, [`blis`](https://github.com/explosion/cython-blis) (optional) and custom Cython. | -| `CupyOps` | | | | Execute via [`cupy`](https://cupy.chainer.org/) and custom CUDA. | +| Backend | CPU | GPU | TPU | Description | +| ---------- | :----------------: | :----------------: | :---------------: | ----------------------------------------------------------------------------------------------------------- | +| `AppleOps` | | | | Use AMX matrix multiplication units on Apple Silicon Macs. Added in Thinc 9.0. | +| `CupyOps` | | | | Execute via [`cupy`](https://cupy.chainer.org/) and custom CUDA. | +| `MPSOps` | | | | Use the GPU on Apple Silicon Macs for PyTorch models, use AMX matrix multiplication units for Thinc Models. | +| `NumpyOps` | | | | Execute via `numpy`, [`blis`](https://github.com/explosion/cython-blis) (optional) and custom Cython. | ## Ops {#ops tag="class"} -The `Ops` class is typically not used directly but via `NumpyOps` or `CupyOps`, -which are subclasses of `Ops` and implement a **more efficient subset of the -methods**. You also have access to the ops via the +The `Ops` class is typically not used directly but via `NumpyOps`, `AppleOps`, +`CupyOps` or `MPSOps`, which are subclasses of `Ops` and implement a **more +efficient subset of the methods**. You also have access to the ops via the [`Model.ops`](/docs/api-model#attributes) attribute. The documented methods below list which backends provide optimized and more efficient versions (indicated by ), and which use the default implementation. @@ -56,7 +58,7 @@ use_ops(blis_ops) | Name | Type | Description | | ------------- | ------------ | ---------------------------------------------------------------------------------------- | -| `name` | str | **Class attribute:** Backend name, `"numpy"` or `"cupy"`. | +| `name` | str | **Class attribute:** Backend name, `"numpy"`, `"apple"`, `"cupy"` or `"mps"`. | | `xp` | Xp | **Class attribute:** `numpy` or `cupy`. | | `device_type` | str | The device type to use, if available for the given backend: `"cpu"`, `"gpu"` or `"tpu"`. | | `device_id` | int | The device ID to use, if available for the given backend. 
| @@ -1553,7 +1555,7 @@ numpy_ops = get_ops("numpy") | Argument | Type | Description | | ----------- | ------------ | ----------------------------------------------------- | -| `ops` | str | `"numpy"` or `"cupy"`. | +| `ops` | str | `"numpy"`, `"apple"`, `"cupy"` or `"mps"`. | | `**kwargs` | | Optional arguments passed to [`Ops.__init__`](#init). | | **RETURNS** | Ops | The backend object. | @@ -1572,7 +1574,7 @@ with use_ops("cupy"): | Argument | Type | Description | | ---------- | ------------ | ----------------------------------------------------- | -| `ops` | str | `"numpy"` or `"cupy"`. | +| `ops` | str | `"numpy"`, `"apple"`, `"cupy"` or `"mps"`. | | `**kwargs` | | Optional arguments passed to [`Ops.__init__`](#init). | ### get_current_ops {#get_current_ops tag="function"} diff --git a/website/docs/api-model.md b/website/docs/api-model.md index 597f67ec9..193fd1acb 100644 --- a/website/docs/api-model.md +++ b/website/docs/api-model.md @@ -84,19 +84,19 @@ model = Model( ) ``` -| Argument | Type | Description | -| -------------- | ------------------------------------------- | --------------------------------------------------------------------------------------- | -| `name` | str | The name of the layer type. | -| `forward` | Callable | Function to compute the forward result and the backpropagation callback. | -| _keyword-only_ | | | -| `init` | Callable | Function to define the initialization logic. | -| `dims` | Dict[str, Optional[int]] | Dictionary describing the model's dimensions. Map unknown dimensions to `None`. | -| `params` | Dict[str, Optional[FloatsXd]] | Dictionary with the model's parameters. Set currently unavailable parameters to `None`. | -| `refs` | Dict[str, Optional[Model]] | Dictionary mapping specific nodes (sublayers) of the network to a name. | -| `attrs` | Dict[str, Any] | Dictionary of non-parameter attributes. | -| `layers` | List[Model] | List of child layers. | -| `shims` | List[Shim] | List of interfaces for external models. | -| `ops` | Optional[Union[NumpyOps, CupyOps]] | An `Ops` instance, which provides mathematical and memory operations. | +| Argument | Type | Description | +| -------------- | ------------------------------------------------------------- | --------------------------------------------------------------------------------------- | +| `name` | str | The name of the layer type. | +| `forward` | Callable | Function to compute the forward result and the backpropagation callback. | +| _keyword-only_ | | | +| `init` | Callable | Function to define the initialization logic. | +| `dims` | Dict[str, Optional[int]] | Dictionary describing the model's dimensions. Map unknown dimensions to `None`. | +| `params` | Dict[str, Optional[FloatsXd]] | Dictionary with the model's parameters. Set currently unavailable parameters to `None`. | +| `refs` | Dict[str, Optional[Model]] | Dictionary mapping specific nodes (sublayers) of the network to a name. | +| `attrs` | Dict[str, Any] | Dictionary of non-parameter attributes. | +| `layers` | List[Model] | List of child layers. | +| `shims` | List[Shim] | List of interfaces for external models. | +| `ops` | Optional[Union[NumpyOps, AppleOps, CupyOps, MPSOps]] | An `Ops` instance, which provides mathematical and memory operations. | ### Model.define_operators {#define_operators tag="classmethod,contextmanager"} @@ -260,17 +260,17 @@ for node in model.walk(): The `walk` method supports three iteration orders through the `order` argument: -* `"bfs"`: breadth-first. 
Iteration order of the example above: - *1 - 2 - 4 - 3 - 5* -* `"dfs_pre"`: depth-first preorder, outputs a node before its children. - Iteration order of the example above: *1 - 2 - 3 - 4 - 5* -* `"dfs_post"`: depth-first postorder, outputs children before a node itself. - Iteration order of the example above: *3 - 2 - 5 - 4 - 1* +- `"bfs"`: breadth-first. Iteration order of the example above: _1 - 2 - 4 - 3 - + 5_ +- `"dfs_pre"`: depth-first preorder, outputs a node before its children. + Iteration order of the example above: _1 - 2 - 3 - 4 - 5_ +- `"dfs_post"`: depth-first postorder, outputs children before a node itself. + Iteration order of the example above: _3 - 2 - 5 - 4 - 1_ -| Argument | Type | Description | -|-------------|--------------------------|--------------------------------------------------------------------------------------------------------------------------------------------| +| Argument | Type | Description | +| ----------- | ------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------- | | `order` | str | Node iteration order. `"bfs"` (breadth-first), `"dfs_pre"` (depth-first preorder), `"dfs_post"` (depth-first postorder) Default: `"bfs"`. | -| **RETURNS** | Iterable[Model] | The layers of the model. | +| **RETURNS** | Iterable[Model] | The layers of the model. | ### Model.remove_node {#remove_node tag="method"} @@ -329,9 +329,9 @@ assert model.get_dim("nI") == 16 Retrieve the value of a dimension of the given name, or `None` if the dimension is either unregistered or the value is currently unset. -| Argument | Type | Description | -| ----------- | --------------------- | --------------------------------------- | -| `name` | str | The name of the dimension, e.g. `"nO"`. | +| Argument | Type | Description | +| ----------- | ---------------------- | --------------------------------------- | +| `name` | str | The name of the dimension, e.g. `"nO"`. | | **RETURNS** | Optional[int] | The size of the dimension, or `None`. | ### Model.set_dim {#set_dim tag="method"} diff --git a/website/docs/api-optimizers.md b/website/docs/api-optimizers.md index 47873cc1c..2deab184e 100644 --- a/website/docs/api-optimizers.md +++ b/website/docs/api-optimizers.md @@ -14,10 +14,9 @@ zero the gradients in place. The optimizers are registered in the ### SGD {#sgd tag="function"} -If a hyperparameter specifies a schedule as a list or generator, its value will -be replaced with the next item on each call to -[`Optimizer.step_schedules`](#step-schedules). Once the schedule is exhausted, -its last value will be used. +Function to create a SGD optimizer. If a hyperparameter specifies a schedule, +the step that is passed to the schedule will be incremented on each call to +[`Optimizer.step_schedules`](#step-schedules). @@ -58,10 +57,9 @@ use_averages = true ### Adam {#adam tag="function"} Function to create an Adam optimizer. Returns an instance of -[`Optimizer`](#optimizer). If a hyperparameter specifies a schedule as a list or -generator, its value will be replaced with the next item on each call to -[`Optimizer.step_schedules`](#step-schedules). Once the schedule is exhausted, -its last value will be used. +[`Optimizer`](#optimizer). If a hyperparameter specifies a schedule, the step +that is passed to the schedule will be incremented on each call to +[`Optimizer.step_schedules`](#step-schedules). 
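The documentation changes above describe the new contract: the optimizer only advances a step counter, and each hyperparameter is resolved by calling its schedule with that step. A sketch of the interaction, using the `Adam` and `warmup_linear` functions updated in this diff (the asserted values follow the `warmup_linear` formula):

```python
from thinc.api import Adam, warmup_linear

optimizer = Adam(learn_rate=warmup_linear(0.001, 2, 10))
assert optimizer.learn_rate(optimizer.step) == 0.0     # step 0: warmup starts at zero
optimizer.step_schedules()
assert optimizer.learn_rate(optimizer.step) == 0.0005  # step 1: halfway through warmup
optimizer.step_schedules()
assert optimizer.learn_rate(optimizer.step) == 0.001   # step 2: warmup finished
```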
@@ -113,10 +111,9 @@ use_averages = true ### RAdam {#radam tag="function"} Function to create an RAdam optimizer. Returns an instance of -[`Optimizer`](#optimizer). If a hyperparameter specifies a schedule as a list or -generator, its value will be replaced with the next item on each call to -[`Optimizer.step_schedules`](#step-schedules). Once the schedule is exhausted, -its last value will be used. +[`Optimizer`](#optimizer). If a hyperparameter specifies a schedule, the step +that is passed to the schedule will be incremented on each call to +[`Optimizer.step_schedules`](#step-schedules). @@ -171,10 +168,9 @@ momentum. Currently support "vanilla" SGD, Adam, and RAdam. ### Optimizer.\_\_init\_\_ {#init tag="method"} -Initialize an optimizer. If a hyperparameter specifies a schedule as a list or -generator, its value will be replaced with the next item on each call to -[`Optimizer.step_schedules`](#step-schedules). Once the schedule is exhausted, -its last value will be used. +Initialize an optimizer. If a hyperparameter specifies a schedule, the step that +is passed to the schedule will be incremented on each call to +[`Optimizer.step_schedules`](#step-schedules). ```python ### Example @@ -213,9 +209,8 @@ and parameter name. ### Optimizer.step_schedules {#step_schedules tag="method"} -Replace the the named hyperparameters with the next item from the schedules -iterator, if available. Once the schedule is exhausted, its last value will be -used. +Increase the current step of the optimizer. This step will be used by schedules +to determine their next value. ```python ### Example diff --git a/website/docs/api-schedules.md b/website/docs/api-schedules.md index f15877111..c3837055b 100644 --- a/website/docs/api-schedules.md +++ b/website/docs/api-schedules.md @@ -5,11 +5,115 @@ next: /docs/api-loss Schedules are generators that provide different rates, schedules, decays or series. They're typically used for batch sizes or learning rates. You can easily -implement your own schedules as well: just write your own generator function, -that produces whatever series of values you need. A common use case for -schedules is within [`Optimizer`](/docs/api-optimizer) objects, which accept -iterators for most of their parameters. See the -[training guide](/docs/usage-training) for details. +implement your own schedules as well: just write your own +[`Schedule`](#schedule) implementation, that produces whatever series of values +you need. A common use case for schedules is within +[`Optimizer`](/docs/api-optimizer) objects, which accept iterators for most of +their parameters. See the [training guide](/docs/usage-training) for details. + +## Schedule {#schedule tag="class" new="9"} + +Class for implementing Thinc schedules. + + + +There's only one `Schedule` class in Thinc and schedules are built using +**composition**, not inheritance. This means that a schedule or composed +schedule will return an **instance** of `Schedule` – it doesn't subclass it. To +read more about this concept, see the pages on +[Thinc's philosophy](/docs/concept). + + + +### Typing {#typing} + +`Schedule` can be used as a +[generic type](https://docs.python.org/3/library/typing.html#generics) with one +parameter. This parameter specifies the type that is returned by the schedule. +For instance, `Schedule[int]` denotes a scheduler that returns integers when +called. A mismatch will cause a type error. For more details, see the docs on +[type checking](/docs/usage-type-checking). 
+ +```python +from thinc.api import Schedule + +def my_function(schedule: Schedule[int]): + ... +``` + +### Attributes {#attributes} + +| Name | Type | Description | +| ------ | ------------ | ------------------------------- | +| `name` | str | The name of the scheduler type. | + +### Properties {#properties} + +| Name | Type | Description | +| ------- | ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `attrs` | Dict[str, Any] | The scheduler attributes. You can use the dict directly and assign _to_ it – but you cannot reassign `schedule.attrs` to a new variable: `schedule.attrs = {}` will fail. | + +### Schedule.\_\_init\_\_ {#init tag="method"} + +Initialize a new schedule. + +```python +### Example +schedule = Schedule( + "constant", + constant_schedule, + attrs={"rate": rate}, +) +``` + +| Argument | Type | Description | +| -------------- | ----------------------- | -------------------------------------------------------- | +| `name` | str | The name of the schedule type. | +| `schedule` | Callable | Function to compute the schedule value for a given step. | +| _keyword-only_ | | | +| `attrs` | Dict[str, Any] | Dictionary of non-parameter attributes. | + +### Schedule.\_\_call\_\_ {#call tag="method"} + +Call the schedule function, returning the value for the given step. The `step` +positional argument is always required. Some schedules may require additional +keyword arguments. + +```python +### Example +from thinc.api import constant + +schedule = constant(0.1) +assert schedule(0) == 0.1 +assert schedule(1000) == 0.1 +``` + +| Argument | Type | Description | +| ----------- | ------------ | ------------------------------------------ | +| `step` | int | The step to compute the schedule for. | +| `**kwargs` | | Optional arguments passed to the schedule. | +| **RETURNS** | Any | The schedule value for the step. | + +### Schedule.to_generator {#to_generator tag="method"} + +Turn the schedule into a generator by passing monotonically increasing step +count into the schedule. + +```python +### Example +from thinc.api import constant + +g = constant(0.1).to_generator() +assert next(g) == 0.1 +assert next(g) == 0.1 +``` + +| Argument | Type | Description | +| ----------- | ------------------------------------ | ------------------------------------------------------------------------------- | +| `start` | int | The initial schedule step. Defaults to `0`. | +| `step_size` | int | The amount to increase the step with for each generated value. Defaults to `1`. | +| `**kwargs` | | Optional arguments passed to the schedule. | +| **RETURNS** | Generator[OutT, None, None] | The generator. | ## constant {#constant tag="function"} @@ -24,7 +128,7 @@ Yield a constant rate. from thinc.api import constant batch_sizes = constant(0.001) -batch_size = next(batch_sizes) +batch_size = batch_sizes(step=0) ``` ```ini @@ -58,7 +162,7 @@ learn_rates = constant_then( 1000, decaying(0.005, 1e-4) ) -learn_rate = next(learn_rates) +learn_rate = learn_rates(step=0) ``` ```ini @@ -97,8 +201,8 @@ Yield an infinite series of linearly decaying values, following the schedule from thinc.api import decaying learn_rates = decaying(0.005, 1e-4) -learn_rate = next(learn_rates) # 0.001 -learn_rate = next(learn_rates) # 0.00999 +learn_rate = learn_rates(step=0) # 0.001 +learn_rate = learn_rates(step=1) # 0.00999 ``` ```ini @@ -135,8 +239,8 @@ rate. 
from thinc.api import compounding batch_sizes = compounding(1.0, 32.0, 1.001) -batch_size = next(batch_sizes) # 1.0 -batch_size = next(batch_sizes) # 1.0 * 1.001 +batch_size = batch_sizes(step=0) # 1.0 +batch_size = batch_sizes(step=1) # 1.0 * 1.001 ``` ```ini @@ -174,7 +278,7 @@ and then a linear decline. Used for learning rates. from thinc.api import warmup_linear learn_rates = warmup_linear(0.01, 3000, 6000) -learn_rate = next(learn_rates) +learn_rate = learn_rates(step=0) ``` ```ini @@ -210,7 +314,7 @@ triangular learning rate" schedule. from thinc.api import slanted_triangular learn_rates = slanted_triangular(0.1, 5000) -learn_rate = next(learn_rates) +learn_rate = learn_rates(step=0) ``` ```ini @@ -251,7 +355,7 @@ Linearly increasing then linearly decreasing the rate at each cycle. from thinc.api import cyclic_triangular learn_rates = cyclic_triangular(0.005, 0.001, 1000) -learn_rate = next(learn_rates) +learn_rate = learn_rates(step=0) ``` ```ini @@ -271,3 +375,47 @@ period = 1000 | `max_lr` | float | | `period` | int | | **YIELDS** | float | + +## plateau {#plateau tag="function" new="9"} + +Yields values from the wrapped schedule, exponentially scaled by the number of +times optimization has plateaued. The caller must pass model evaluation scores +through the `last_score` argument for the scaling to be adjusted. The last +evaluation score is passed through the `last_score` argument as a tuple +(`last_score_step`, `last_score`). This tuple indicates when a model was last +evaluated (`last_score_step`) and with what score (`last_score`). + + + +```python +### {small="true"} +from thinc.api import constant, plateau + +schedule = plateau(2, 0.5, constant(1.0)) +assert schedule(step=0, last_score=(0, 1.0)) == 1.0 +assert schedule(step=1, last_score=(1, 1.0)) == 1.0 +assert schedule(step=2, last_score=(2, 1.0)) == 0.5 +assert schedule(step=3, last_score=(3, 1.0)) == 0.5 +assert schedule(step=4, last_score=(4, 1.0)) == 0.25 +``` + +```ini +### config {small="true"} +[learn_rate] +@schedules = "plateau.v1" +scale = 0.5 +max_patience = 2 + +[learn_rate.shedule] +@schedules = "constant.v1" +rate = 1.0 +``` + + + +| Argument | Type | Description | +| -------------- | ------------------------ | ------------------------------------------------------------------------------------- | ----------------------------------------------- | +| `max_patience` | int | Number of evaluations without an improvement to consider the model to have plateaued. | +| `scale` | float | | Scaling of the inner schedule after plateauing. | +| `schedule` | Schedule[float] | | The schedule to wrap. | +| **RETURNS** | Schedule[float] | | diff --git a/website/docs/usage-config.md b/website/docs/usage-config.md index 73a1638ac..2887c39d5 100644 --- a/website/docs/usage-config.md +++ b/website/docs/usage-config.md @@ -190,21 +190,30 @@ For details and examples, see the The function registry integration becomes even more powerful when used to build **recursive structures**. Let's say you want to use a learning rate schedule and -pass in a generator as the `learn_rate` argument. Here's an example of a -function that yields an infinite series of decaying values, following the -schedule `base_rate * 1 / (1 + decay * t)`. It's also available in Thinc as +pass in a schedule as the `learn_rate` argument. Here's an example of a function +that yields an infinite series of decaying values, following the schedule +`base_rate * 1 / (1 + decay * t)`. It's also available in Thinc as [`schedules.decaying`](/docs/api-schedules#decaying). 
The decorator registers the function `"my_cool_decaying_schedule.v1"` in the registry `schedules`: ```python -from typing import Iterable import thinc +from thinc.schedules import Schedule @thinc.registry.schedules("my_cool_decaying_schedule.v1") -def decaying(base_rate: float, decay: float, *, t: int = 0) -> Iterable[float]: - while True: - yield base_rate * (1.0 / (1.0 + decay * t)) - t += 1 +def decaying(base_rate: float, decay: float, *, t: int = 0) -> Schedule[float]: + return Schedule( + "decaying", + _decaying_schedule, + attrs={"base_rate": base_rate, "decay": decay, "t": t} + ) + + +def _decaying_schedule(schedule: Schedule, step: int, **kwargs) -> float: + base_rate = schedule.attrs["base_rate"] + decay = schedule.attrs["decay"] + t = schedule.attrs["t"] + return base_rate * (1.0 / (1.0 + decay * (step + t))) ``` In your config, you can now define the `learn_rate` as a subsection of @@ -230,15 +239,6 @@ argument. If type annotations are available for the return value and it's a type that can be evaluated, the return value of the function will be validated as well. - - -**A note on validating generators:** If a value is a generator, it won't be -validated further, since this would mean having to execute and consume it. -Generators can potentially be infinite – like the decaying schedule in this -example – so checking its return value isn't viable. - - - ```python ### Under the hood learn_rate_func = thinc.registry.get("schedules", "my_cool_decaying_schedule.v1") @@ -290,11 +290,22 @@ values: ```python ### {small="true"} +import thinc +from thinc.schedules import Schedule + @thinc.registry.schedules("my_cool_schedule.v1") -def schedule(*steps: float, final: float = 1.0) -> Iterable[float]: - yield from steps - while True: - yield final +def step_values(*steps: float, final: float = 1.0) -> Schedule[float]: + step_list = list(steps) + return Schedule( + "step_values", + _step_values_schedule, + attrs={"steps": list(steps), "final": final} + ) + +def _step_values_schedule(schedule: Schedule, step: int, **kwargs) -> float: + steps = schedule.attrs["steps"] + final = schedule.attrs["final"] + return steps[step] if step < len(steps) else final ``` ```ini diff --git a/website/docs/usage-training.md b/website/docs/usage-training.md index c34648b89..8df7127a3 100644 --- a/website/docs/usage-training.md +++ b/website/docs/usage-training.md @@ -120,10 +120,9 @@ also simply consume the entire generator, by calling `list()` on it. Finally, `minibatch` and `multibatch` support **variable length batching**, based on a schedule you can provide as the `batch_size` argument. Simply pass in -an iterable (such as a generator from the -[built-in schedules](/docs/api-schedules)) instead of an integer. Variable -length batching is non-standard, but we regularly use it for some of -[spaCy](https://spacy.io)'s models, especially the parser and entity recognizer. +an iterable. Variable length batching is non-standard, but we regularly use it +for some of [spaCy](https://spacy.io)'s models, especially the parser and entity +recognizer. ```python from thinc.api import compounding @@ -225,37 +224,39 @@ normalize = true A common trick for stochastic gradient descent is to **vary the learning rate or other hyperparameters** over the course of training. Since there are many possible ways to vary the learning rate, Thinc lets you implement hyperparameter -schedules as simple generator functions. Thinc also provides a number of -[popular schedules](/docs/api-schedules) built-in. 
- -You can use schedules directly, by calling `next()` on the schedule and using it -to update hyperparameters in your training loop. Since schedules are -particularly common for optimization settings, the -[`Optimizer`](/docs/api-optimizer) object accepts floats, lists and iterators -for most of its parameters. When you call -[`Optimizer.step_schedules`](/docs/api-optimizer#step_schedules), the optimizer -will draw the next value from the generators and use them to change the given -attributes. For instance, here's how to create an instance of the `Adam` -optimizer with a custom learning rate schedule: +schedules as instances of the [`Schedule`](/docs/api-schedules#schedule) class. +Thinc also provides a number of [popular schedules](/docs/api-schedules) +built-in. + +You can use schedules directly, by calling the schedule with the `step` keyword +argument and using it to update hyperparameters in your training loop. Since +schedules are particularly common for optimization settings, the +[`Optimizer`](/docs/api-optimizer) object accepts floats, lists, iterators, and +[`Schedule`](/docs/api-schedules#schedule) instances for most of its parameters. +When you call [`Optimizer.step_schedules`](/docs/api-optimizer#step_schedules), +the optimizer will increase its step count and pass it to the schedules. For +instance, this is how one creates an instance of the `Adam` optimizer with a +custom learning rate schedule: ```python ### Custom learning rate schedule -from thinc.api import Adam +from thinc.api import Adam, Schedule -def my_schedule(): +def cycle(): values = [0.001, 0.01, 0.1] - while True: - for value in values: - yield value - for value in reversed(values): - yield value - -optimizer = Adam(learn_rate=my_schedule()) -assert optimizer.learn_rate == 0.001 + all_values = values + list(reversed(values)) + return Schedule("cycle", _cycle_schedule, attrs={"all_values": all_values}) + +def _cycle_schedule(schedule: Schedule, step: int, **kwargs) -> float: + all_values = schedule.attrs["all_values"] + return all_values[step % len(all_values)] + +optimizer = Adam(learn_rate=cycle()) +assert optimizer.learn_rate(optimizer.step) == 0.001 optimizer.step_schedules() -assert optimizer.learn_rate == 0.01 +assert optimizer.learn_rate(optimizer.step) == 0.01 optimizer.step_schedules() -assert optimizer.learn_rate == 0.1 +assert optimizer.learn_rate(optimizer.step) == 0.1 ``` ![](images/schedules_custom1.svg) @@ -271,13 +272,14 @@ of the optimizer. Check out the ```python ### Registered function {small="true"} -@thinc.registry.schedules("my_schedule.v1") -def my_schedule(values): - while True: - for value in values: - yield value - for value in reversed(values): - yield value +@thinc.registry.schedules("cycle.v1") +def cycle(values): + all_values = values + list(reversed(values)) + return Schedule("cycle", _cycle_schedule, attrs={"all_values": all_values}) + +def _cycle_schedule(schedule: Schedule, step: int, **kwargs) -> float: + all_values = schedule.attrs["all_values"] + return all_values[step % len(all_values)] ``` ```ini @@ -286,7 +288,7 @@ def my_schedule(values): @optimizers = "Adam.v1" [optimizer.learn_rate] -@schedules = "my_schedule.v1" +@schedules = "cycle.v1" values = [0.001, 0.01, 0.1] ```
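One piece not spelled out in the docs above is how the new `Optimizer.last_score` property feeds the `plateau` schedule. A minimal sketch of that wiring, based on the `plateau`, `constant` and `Optimizer.last_score` additions in this diff; during real training the optimizer forwards `last_score` to its schedules automatically through `_schedule_args`, so the direct schedule call below is only there to show the effect:

```python
from thinc.api import Adam, constant, plateau

optimizer = Adam(learn_rate=plateau(max_patience=2, scale=0.5, schedule=constant(0.001)))

# Report an evaluation score; it is stored together with the current step.
optimizer.last_score = 0.85
assert optimizer.last_score == (0, 0.85)

# Calling the schedule by hand, the way the optimizer does during an update:
rate = optimizer.learn_rate(optimizer.step, last_score=optimizer.last_score)
assert rate == 0.001  # no plateau yet, so the inner constant rate is returned
```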