Merge v9 into thinc.ai #931

Merged 53 commits from v9 into thinc.ai on Apr 18, 2024.

Commits

64967eb  Remove `thinc.extra.search` module and related tests (moved to spaCy)… (shadeMe, Aug 17, 2022)
b0c9be8  Merge branch 'master' into chore/merge-master-into-v9 (shadeMe, Sep 9, 2022)
a80d275  Merge pull request #764 from shadeMe/chore/merge-master-into-v9 (svlandeg, Sep 9, 2022)
43ef766  `NumpyOps` cleanup (#760) (shadeMe, Sep 13, 2022)
17c823e  disable mypy run for Python 3.10 (#768) (#769) (svlandeg, Sep 15, 2022)
0366934  Remove vestigial/mostly unused `backends.linalg` module (#742) (shadeMe, Sep 16, 2022)
de40bdf  Standardize `blis` calls in `NumpyOps` (#763) (shadeMe, Sep 16, 2022)
372ecf5  Merge branch 'master' into update/v9 (svlandeg, Oct 10, 2022)
20d97d9  Merge branch 'master' into update/v9 (adrianeboyd, Oct 26, 2022)
166f39b  Merge branch 'master' into update/v9 (adrianeboyd, Oct 27, 2022)
b979a94  Merge branch 'master' into update/v9 (adrianeboyd, Oct 27, 2022)
9e3acb8  Merge pull request #784 from svlandeg/update/v9 (adrianeboyd, Oct 27, 2022)
c8ac07f  Cross entropy fix (#647) (kadarakos, Oct 28, 2022)
9039d67  Merge remote-tracking branch 'upstream/master' into sync/v9-master-20… (danieldk, Dec 5, 2022)
c5ad06d  Merge pull request #810 from danieldk/sync/v9-master-20221205 (svlandeg, Dec 5, 2022)
cdc9717  Bring back support for missing labels to legacy cross entropy (#809) (danieldk, Dec 9, 2022)
9743709  Set version to v9.0.0.dev0 (#816) (danieldk, Dec 9, 2022)
07f8f88  Fix spurious `v` prefix in the version number (#818) (danieldk, Dec 9, 2022)
4645083  Merge remote-tracking branch 'upstream/master' into chores/merge-mast… (danieldk, Dec 20, 2022)
6e91494  Merge pull request #826 from danieldk/chores/merge-master-20221220 (danieldk, Dec 20, 2022)
717c70e  Give schedules access to the key, step, and last eval score (#804) (danieldk, Dec 22, 2022)
f6f6c81  Set version to v9.0.0.dev1 (#829) (danieldk, Dec 23, 2022)
7f35b3c  Add `Schedule.to_generator` (#837) (danieldk, Jan 12, 2023)
9704ad9  Merge remote-tracking branch 'upstream/master' into chore/v9-merge-ma… (danieldk, Jan 12, 2023)
ece7eec  Merge pull request #840 from danieldk/chore/v9-merge-master-20230112 (danieldk, Jan 12, 2023)
bbe8f53  Set version to v9.0.0.dev2 (danieldk, Jan 12, 2023)
e58ca10  Merge pull request #841 from danieldk/chore/bump-v9.0.0.dev2 (danieldk, Jan 12, 2023)
f576d1e  Add plateau.v1 schedule (#842) (danieldk, Jan 19, 2023)
fc24e8a  Smooth one hot fix (#830) (kadarakos, Feb 1, 2023)
82c9151  Merge remote-tracking branch 'upstream/master' into v9 (danieldk, Mar 22, 2023)
ab9c439  Merge pull request #867 from danieldk/chore/update-v9-master-20230322 (danieldk, Mar 22, 2023)
bf0e276  Set version to v9.0.0.dev3 (#868) (danieldk, Mar 22, 2023)
46ceec5  Merge remote-tracking branch 'upstream/master' into chore/update-v9-m… (adrianeboyd, Apr 27, 2023)
886231f  Merge pull request #874 from adrianeboyd/chore/update-v9-master-1 (adrianeboyd, Apr 28, 2023)
816ea33  Temporarily revert new loss implementations (#916) (danieldk, Jan 8, 2024)
95f894f  isort (danieldk, Jan 9, 2024)
e570a1a  Merge remote-tracking branch 'upstream/master' into maintenance/v9-me… (danieldk, Jan 9, 2024)
d34f536  strings2arrays: make work again for sequences of inequal length (danieldk, Jan 9, 2024)
5c46b82  Fix local thread storage usage and make it typecheck (danieldk, Jan 9, 2024)
09e9555  Fixup imports that lead to type checking issues (danieldk, Jan 9, 2024)
6c314d2  Fix strings2array (#918) (svlandeg, Jan 16, 2024)
bf2e00b  Merge pull request #917 from danieldk/maintenance/v9-merge-master-202… (danieldk, Jan 16, 2024)
40d4148  Set version to v9.0.0.dev4 (#919) (danieldk, Jan 16, 2024)
307a4f8  Fix `cupy.cublas` import (#921) (danieldk, Feb 7, 2024)
3aae298  Set version to v8.2.3 (#922) (danieldk, Feb 7, 2024)
c1c72f2  Merge remote-tracking branch 'upstream/master' into maintenance/merge… (danieldk, Feb 12, 2024)
dec431c  Merge pull request #923 from danieldk/maintenance/merge-master-20240212 (svlandeg, Feb 23, 2024)
ec68d7d  Set version to 9.0.0.dev5 (#925) (danieldk, Apr 8, 2024)
c998bf2  Merge `thinc-apple-ops` into Thinc (#927) (danieldk, Apr 16, 2024)
2a0b9c1  Set version to 9.0.0.dev6 (#928) (danieldk, Apr 17, 2024)
ccae258  Document `AppleOps` and `MPSOps` (#929) (danieldk, Apr 18, 2024)
5be631e  Set version to 9.0.0 (#930) (danieldk, Apr 18, 2024)
f348090  Merge branch 'v9' into thinc.ai (danieldk, Apr 18, 2024)

Files changed

.github/workflows/tests.yml (3 additions, 12 deletions)

@@ -87,7 +87,9 @@ jobs:
       - name: Run mypy
         run: python -m mypy thinc --no-implicit-reexport
-        if: matrix.python_version != '3.6'
+        if: |
+          matrix.python_version != '3.6' &&
+          matrix.python_version != '3.7'

       - name: Delete source directory
         run: rm -rf thinc
@@ -150,14 +152,3 @@ jobs:

       - name: Run tests with extras
         run: python -m pytest --pyargs thinc --cov=thinc --cov-report=term -p thinc.tests.enable_tensorflow -p thinc.tests.enable_mxnet
-
-      - name: Run tests for thinc-apple-ops
-        run: |
-          pip uninstall -y tensorflow
-          pip install thinc-apple-ops
-          python -m pytest --pyargs thinc_apple_ops
-        if: matrix.os == 'macos-latest' && matrix.python_version == '3.10'
-
-      - name: Run tests with thinc-apple-ops
-        run: python -m pytest --pyargs thinc
-        if: matrix.os == 'macos-latest' && matrix.python_version == '3.10'

requirements.txt (1 addition, 1 deletion)

@@ -25,7 +25,7 @@
 pytest-cov>=2.7.0,<5.0.0
 coverage>=5.0.0,<8.0.0
 mock>=2.0.0,<3.0.0
 flake8>=3.5.0,<3.6.0
-mypy>=1.0.0,<1.1.0; python_version >= "3.7"
+mypy>=1.5.0,<1.6.0; platform_machine != "aarch64" and python_version >= "3.8"
 types-mock>=0.1.1
 types-contextvars>=0.1.2; python_version < "3.7"
 types-dataclasses>=0.1.3; python_version < "3.7"

setup.py (14 additions, 4 deletions)

@@ -1,4 +1,5 @@
 #!/usr/bin/env python
+import platform
 import sys
 from setuptools.command.build_ext import build_ext
 from sysconfig import get_path
@@ -13,16 +14,16 @@
 # http://docs.cython.org/en/latest/src/userguide/source_files_and_compilation.html#compiler-options
 Options.docstrings = True

+ACCELERATE = "thinc.backends._accelerate"
+APPLE_OPS = ["thinc.backends.apple_ops", ACCELERATE]

 PACKAGES = find_packages()
 MOD_NAMES = [
     "thinc.backends.cblas",
-    "thinc.backends.linalg",
     "thinc.backends.numpy_ops",
-    "thinc.extra.search",
     "thinc.layers.sparselinear",
     "thinc.layers.premap_ids",
-]
+] + (APPLE_OPS if platform.system() == "Darwin" else [])
 COMPILE_OPTIONS = {
     "msvc": ["/Ox", "/EHsc"],
     "other": ["-O3", "-Wno-strict-prototypes", "-Wno-unused-function", "-std=c++11"],
@@ -80,7 +81,16 @@ def setup_package():
     ext_modules = []
     for name in MOD_NAMES:
         mod_path = name.replace(".", "/") + ".pyx"
-        ext = Extension(name, [mod_path], language="c++", include_dirs=include_dirs)
+        if name == ACCELERATE:
+            ext = Extension(
+                name,
+                [mod_path],
+                language="c++",
+                include_dirs=include_dirs,
+                libraries=["blas"],
+            )
+        else:
+            ext = Extension(name, [mod_path], language="c++", include_dirs=include_dirs)
         ext_modules.append(ext)
     print("Cythonizing sources")
     ext_modules = cythonize(

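With this change, the Accelerate-backed modules are compiled only on macOS. A quick, hypothetical way to check whether a given build actually includes them (module names taken from the diff above):

import platform

if platform.system() == "Darwin":
    # Only built when setup.py ran on macOS (see the Darwin branch above).
    from thinc.backends import _accelerate, apple_ops
    print("Accelerate-backed extensions are available")
else:
    print("Non-macOS platform: the Apple extensions are not built")
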
thinc/about.py (1 addition, 1 deletion)

@@ -1,2 +1,2 @@
-__version__ = "8.2.2"
+__version__ = "9.0.0"
 __release__ = True

thinc/api.py (10 additions, 3 deletions)

@@ -119,11 +119,13 @@
 )
 from .optimizers import SGD, Adam, Optimizer, RAdam
 from .schedules import (
+    Schedule,
     compounding,
     constant,
     constant_then,
     cyclic_triangular,
     decaying,
+    plateau,
     slanted_triangular,
     warmup_linear,
 )
@@ -160,6 +162,11 @@
     xp2torch,
 )

+try:
+    from .backends import AppleOps
+except ImportError:
+    AppleOps = None
+
 # fmt: off
 __all__ = [
     # .config
@@ -179,8 +186,8 @@
     # .optimizers
     "Adam", "RAdam", "SGD", "Optimizer",
     # .schedules
-    "cyclic_triangular", "warmup_linear", "constant", "constant_then",
-    "decaying", "slanted_triangular", "compounding",
+    "Schedule", "cyclic_triangular", "warmup_linear", "constant", "constant_then",
+    "decaying", "slanted_triangular", "compounding", "plateau",
     # .types
     "Ragged", "Padded", "ArgsKwargs", "Unserializable",
     # .util
@@ -196,7 +203,7 @@
     "has_cupy",
     # .backends
     "get_ops", "set_current_ops", "get_current_ops", "use_ops",
-    "Ops", "CupyOps", "MPSOps", "NumpyOps", "set_gpu_allocator",
+    "Ops", "AppleOps", "CupyOps", "MPSOps", "NumpyOps", "set_gpu_allocator",
     "use_pytorch_for_gpu_memory", "use_tensorflow_for_gpu_memory",
     # .layers
     "Dropout", "Embed", "expand_window", "HashEmbed", "LayerNorm", "Linear",

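The new exports are then usable straight from `thinc.api`. A minimal sketch; the `plateau` argument names below follow the Thinc schedules documentation for `plateau.v1` and are worth double-checking against the released signature:

from thinc.api import AppleOps, constant, plateau

# AppleOps is None on platforms where the Accelerate extension is absent.
if AppleOps is not None:
    print("Apple Accelerate ops available")

# plateau.v1 wraps another schedule and multiplies its output by `scale`
# each time the evaluation score fails to improve `max_patience` times.
learn_rate = plateau(max_patience=2, scale=0.5, schedule=constant(0.001))
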
thinc/backends/__init__.py (15 additions, 18 deletions)

@@ -19,13 +19,21 @@
 from .numpy_ops import NumpyOps
 from .ops import Ops

+try:
+    from .apple_ops import AppleOps
+except ImportError:
+    AppleOps = None
+
 context_ops: ContextVar[Optional[Ops]] = ContextVar("context_ops", default=None)
 context_pools: ContextVar[dict] = ContextVar("context_pools", default={})

 # Internal use of thread-local storage only for detecting cases where a Jupyter
 # notebook might not have preserved contextvars across cells.
 _GLOBAL_STATE = {"ops": None}

+# Thread-local state.
+_LOCAL_STATE = threading.local()
+

 def set_gpu_allocator(allocator: str) -> None:  # pragma: no cover
     """Route GPU memory allocation via PyTorch or tensorflow.
@@ -80,10 +88,6 @@ def use_tensorflow_for_gpu_memory() -> None:  # pragma: no cover


 def _import_extra_cpu_backends():
-    try:
-        from thinc_apple_ops import AppleOps
-    except ImportError:
-        pass
     try:
         from thinc_bigendian_ops import BigEndianOps
     except ImportError:
@@ -152,22 +156,14 @@ def contextvars_eq_thread_ops() -> bool:
         return False


-def _get_thread_state():
+def _get_thread_state() -> threading.local:
     """Get a thread-specific state variable that inherits from a global
     state when it's created."""
-    thread: threading.Thread = threading.current_thread()
-    if not hasattr(thread, "__local"):
-        thread.__local = _create_thread_local(_GLOBAL_STATE)
-    return thread.__local
-
-
-def _create_thread_local(
-    attrs: Dict[str, Any], local_class: Type[threading.local] = threading.local
-):
-    obj = local_class()
-    for name, value in attrs.items():
-        setattr(obj, name, value)
-    return obj
+    if not hasattr(_LOCAL_STATE, "initialized") or not _LOCAL_STATE.initialized:
+        for name, value in _GLOBAL_STATE.items():
+            setattr(_LOCAL_STATE, name, value)
+        _LOCAL_STATE.initialized = True
+    return _LOCAL_STATE


 __all__ = [
@@ -176,6 +172,7 @@
     "use_ops",
     "ParamServer",
     "Ops",
+    "AppleOps",
     "CupyOps",
     "MPSOps",
     "NumpyOps",

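After this change `AppleOps` is always importable from `thinc.backends` (it is simply `None` when the compiled extension is missing), and the thread-local state lives in a single module-level `threading.local` rather than in ad-hoc attributes on thread objects. A small guarded-usage sketch:

from thinc.api import get_current_ops, use_ops
from thinc.backends import AppleOps

if AppleOps is not None:
    # "apple" is the name AppleOps registers (see apple_ops.pyx below).
    with use_ops("apple"):
        print(type(get_current_ops()).__name__)  # AppleOps
else:
    print("Apple ops unavailable; staying on the default backend")
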
thinc/backends/_accelerate.pxd (new file, 40 additions)

cdef extern from "Accelerate/Accelerate.h":
enum CBLAS_ORDER: CblasRowMajor, CblasColMajor
enum CBLAS_TRANSPOSE: CblasNoTrans, CblasTrans, CblasConjTrans
enum CBLAS_UPLO: CblasUpper, CblasLower
enum CBLAS_DIAG: CblasNonUnit, CblasUnit
enum CBLAS_SIDE: CblasLeft, CblasRight

# BLAS level 1 routines

void cblas_sswap(int M, float *x, int incX, float *y, int incY) nogil
void cblas_sscal(int N, float alpha, float *x, int incX) nogil
void cblas_scopy(int N, float *x, int incX, float *y, int incY) nogil
void cblas_saxpy(int N, float alpha, float *x, int incX, float *y, int incY ) nogil
float cblas_sdot(int N, float *x, int incX, float *y, int incY ) nogil
float cblas_snrm2(int N, float *x, int incX) nogil
float cblas_sasum(int N, float *x, int incX) nogil
int cblas_isamax(int N, float *x, int incX) nogil

# BLAS level 2 routines
void cblas_sgemv(CBLAS_ORDER Order, CBLAS_TRANSPOSE TransA, int M, int N,
float alpha, float *A, int lda, float *x, int incX,
float beta, float *y, int incY) nogil

void cblas_sger(CBLAS_ORDER Order, int M, int N, float alpha, float *x,
int incX, float *y, int incY, float *A, int lda) nogil

# BLAS level 3 routines
void cblas_sgemm(CBLAS_ORDER Order, CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB, int M, int N, int K,
float alpha, float *A, int lda, float *B, int ldb,
float beta, float *C, int ldc) nogil


cdef void sgemm(bint TransA, bint TransB, int M, int N, int K,
float alpha, const float* A, int lda, const float *B,
int ldb, float beta, float* C, int ldc) nogil


cdef void saxpy(int N, float alpha, const float* X, int incX,
float *Y, int incY) nogil

thinc/backends/_accelerate.pyx (new file, 75 additions)

cimport numpy as np
from libc.stdint cimport uintptr_t

import numpy


cpdef np.ndarray gemm(float[:, ::1] A, float[:, ::1] B,
bint trans1=False, bint trans2=False,
np.ndarray out=None):
cdef int nM = A.shape[0] if not trans1 else A.shape[1]
cdef int nK = A.shape[1] if not trans1 else A.shape[0]
cdef int nK_b = B.shape[0] if not trans2 else B.shape[1]
cdef int nN = B.shape[1] if not trans2 else B.shape[0]

cdef float[:, ::1] C = out

if out is None:
out = numpy.empty((nM, nN), dtype="f")
C = out
else:
if C.shape[0] != nM or C.shape[1] != nN:
msg = "Shape mismatch for output matrix, was: (%d, %d), expected (%d, %d)"
raise ValueError(msg % (C.shape[0], C.shape[1], nM, nN))


if nK != nK_b:
msg = "Shape mismatch for gemm: (%d, %d), (%d, %d)"
raise ValueError(msg % (nM, nK, nK_b, nN))

if nM == 0 or nK == 0 or nN == 0:
return out

cblas_sgemm(
CblasRowMajor,
CblasTrans if trans1 else CblasNoTrans,
CblasTrans if trans2 else CblasNoTrans,
nM,
nN,
nK,
1.0,
&A[0, 0],
A.shape[1],
&B[0, 0],
B.shape[1],
0.0,
&C[0, 0],
C.shape[1]
)
return out


cdef void sgemm(bint TransA, bint TransB, int M, int N, int K,
float alpha, const float* A, int lda, const float *B,
int ldb, float beta, float* C, int ldc) nogil:
cblas_sgemm(
CblasRowMajor,
CblasTrans if TransA else CblasNoTrans,
CblasTrans if TransB else CblasNoTrans,
M,
N,
K,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc
)


cdef void saxpy(int N, float alpha, const float* X, int incX,
float *Y, int incY) nogil:
cblas_saxpy(N, alpha, X, incX, Y, incY)
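
A brief sanity check of the Python-visible `gemm` wrapper (macOS builds only; inputs must be C-contiguous float32 matrices, as required by the `float[:, ::1]` memoryviews):

import numpy

from thinc.backends._accelerate import gemm

A = numpy.random.rand(4, 3).astype("f")
B = numpy.random.rand(3, 5).astype("f")

C = gemm(A, B)  # allocates a (4, 5) float32 output
numpy.testing.assert_allclose(C, A @ B, rtol=1e-5)

# trans1/trans2 transpose the respective inputs before multiplying:
Ct = gemm(B, A, trans1=True, trans2=True)  # equals (A @ B).T
numpy.testing.assert_allclose(Ct, (A @ B).T, rtol=1e-5)
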

thinc/backends/apple_ops.pyx (new file, 39 additions)

from typing import Optional

import numpy

from ._accelerate import gemm

from ._accelerate cimport saxpy, sgemm
from .cblas cimport CBlas, set_saxpy, set_sgemm

from .. import registry
from ..types import Floats2d
from .numpy_ops import NumpyOps


@registry.ops("AppleOps")
class AppleOps(NumpyOps):
"""Thinc Ops class that calls into Apple's native libraries for some
operations. Other operations fall back to numpy."""
name = "apple"
xp = numpy

def cblas(self) -> CBlas:
cdef CBlas cblas = CBlas()
set_saxpy(cblas, saxpy)
set_sgemm(cblas, sgemm)
return cblas

def gemm(
self,
x: Floats2d,
y: Floats2d,
out: Optional[Floats2d] = None,
trans1: bool = False,
trans2: bool = False,
) -> Floats2d:
"""Perform General Matrix Multiplication (GeMM) and optionally store
the result in the specified output variable.
"""
return gemm(x, y, out=out, trans1=trans1, trans2=trans2)
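
Because the class is registered under `registry.ops`, it can be selected like any other backend. A short sketch, assuming a macOS build:

from thinc.api import Linear, get_ops, set_current_ops

set_current_ops(get_ops("apple"))  # AppleOps registers itself as "apple"

model = Linear(nO=16, nI=32)
model.initialize()
Y = model.predict(model.ops.alloc2f(8, 32))  # GEMM runs through Accelerate
print(Y.shape)  # (8, 16)
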

thinc/backends/cblas.pxd (12 additions, 1 deletion)

@@ -1,8 +1,11 @@
 from libcpp.memory cimport shared_ptr

 ctypedef void (*sgemm_ptr)(bint transA, bint transB, int M, int N, int K,
-                           float alpha, const float* A, int lda, const float *B,
+                           float alpha, const float* A, int lda, const float* B,
                            int ldb, float beta, float* C, int ldc) nogil
+ctypedef void (*dgemm_ptr)(bint transA, bint transB, int M, int N, int K,
+                           double alpha, const double* A, int lda, const double* B,
+                           int ldb, double beta, double* C, int ldc) nogil


 ctypedef void (*saxpy_ptr)(int N, float alpha, const float* X, int incX,
@@ -12,6 +15,8 @@
 ctypedef void (*daxpy_ptr)(int N, double alpha, const double* X, int incX,
                            double *Y, int incY) nogil

+ctypedef void (*sscal_ptr)(int N, float alpha, float* X, int incX) nogil
+ctypedef void (*dscal_ptr)(int N, double alpha, double* X, int incX) nogil

 # Forward-declaration of the BlasFuncs struct. This struct must be opaque, so
 # that consumers of the CBlas class cannot become dependent on its size or
@@ -32,6 +37,12 @@ cdef class CBlas:
 cdef daxpy_ptr daxpy(CBlas cblas) nogil
 cdef saxpy_ptr saxpy(CBlas cblas) nogil
 cdef sgemm_ptr sgemm(CBlas cblas) nogil
+cdef dgemm_ptr dgemm(CBlas cblas) nogil
+cdef sscal_ptr sscal(CBlas cblas) nogil
+cdef dscal_ptr dscal(CBlas cblas) nogil
 cdef void set_daxpy(CBlas cblas, daxpy_ptr daxpy) nogil
 cdef void set_saxpy(CBlas cblas, saxpy_ptr saxpy) nogil
 cdef void set_sgemm(CBlas cblas, sgemm_ptr sgemm) nogil
+cdef void set_dgemm(CBlas cblas, dgemm_ptr dgemm) nogil
+cdef void set_sscal(CBlas cblas, sscal_ptr sscal) nogil
+cdef void set_dscal(CBlas cblas, dscal_ptr dscal) nogil

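
The new `dgemm`, `sscal` and `dscal` slots follow the existing pattern: Cython consumers fetch a function pointer via an accessor such as `sgemm(cblas)` and call it under `nogil`. From Python, the only visible entry point is the `cblas()` method on an ops instance; a tiny sketch:

from thinc.backends.apple_ops import AppleOps

ops = AppleOps()
cblas = ops.cblas()  # CBlas instance wired to Accelerate's saxpy/sgemm
# The stored function pointers themselves are only callable from Cython.
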