diff --git a/curated_transformers/__init__.py b/curated_transformers/__init__.py
index e69de29b..0fdf4836 100644
--- a/curated_transformers/__init__.py
+++ b/curated_transformers/__init__.py
@@ -0,0 +1,19 @@
+import catalogue
+from catalogue import Registry
+
+
+class registry(object):
+    """
+    Registries for models. These registries are used by the auto classes to
+    discover the available models.
+    """
+
+    causal_lms: Registry = catalogue.create(
+        "curated_transformers", "causal_lms", entry_points=True
+    )
+    decoders: Registry = catalogue.create(
+        "curated_transformers", "decoders", entry_points=True
+    )
+    encoders: Registry = catalogue.create(
+        "curated_transformers", "encoders", entry_points=True
+    )
diff --git a/curated_transformers/models/albert/encoder.py b/curated_transformers/models/albert/encoder.py
index 1a247c23..2c4ab449 100644
--- a/curated_transformers/models/albert/encoder.py
+++ b/curated_transformers/models/albert/encoder.py
@@ -1,4 +1,4 @@
-from typing import Any, Mapping, Optional, Type, TypeVar
+from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
 
 import torch
 from torch import Tensor
@@ -99,6 +99,10 @@ def forward(
 
         return ModelOutput(all_outputs=[embeddings, *layer_outputs])
 
+    @classmethod
+    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
+        return ("albert",)
+
     @classmethod
     def state_dict_from_hf(
         cls: Type[Self], params: Mapping[str, Tensor]
diff --git a/curated_transformers/models/auto_model.py b/curated_transformers/models/auto_model.py
index 94de95a0..e1523a30 100644
--- a/curated_transformers/models/auto_model.py
+++ b/curated_transformers/models/auto_model.py
@@ -1,27 +1,21 @@
+import warnings
 from abc import ABC, abstractmethod
-from typing import Dict, Generic, Optional, Type, TypeVar
+from typing import Generic, Optional, Type, TypeVar
 
 import torch
+from catalogue import Registry
 from fsspec import AbstractFileSystem
 
+from curated_transformers import registry
+
 from ..layers.cache import KeyValueCache
 from ..quantization.bnb.config import BitsAndBytesConfig
 from ..repository.fsspec import FsspecArgs, FsspecRepository
 from ..repository.hf_hub import HfHubRepository
 from ..repository.repository import ModelRepository, Repository
-from .albert import ALBERTEncoder
-from .bert import BERTEncoder
-from .camembert import CamemBERTEncoder
 from .config import TransformerConfig
-from .falcon import FalconCausalLM, FalconDecoder
-from .gpt_neox import GPTNeoXCausalLM, GPTNeoXDecoder
 from .hf_hub import FromHFHub
-from .llama import LlamaCausalLM, LlamaDecoder
 from .module import CausalLMModule, DecoderModule, EncoderModule
-from .mpt.causal_lm import MPTCausalLM
-from .mpt.decoder import MPTDecoder
-from .roberta import RoBERTaEncoder
-from .xlm_roberta import XLMREncoder
 
 ModelT = TypeVar("ModelT")
 
@@ -32,7 +26,7 @@ class AutoModel(ABC, Generic[ModelT]):
     Face Model Hub.
     """
 
-    _hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = {}
+    _registry: Registry
 
     @classmethod
     def _resolve_model_cls(
@@ -40,14 +34,25 @@ def _resolve_model_cls(
         repo: ModelRepository,
     ) -> Type[FromHFHub]:
         model_type = repo.model_type()
-        module_cls = cls._hf_model_type_to_curated.get(model_type)
-        if module_cls is None:
-            raise ValueError(
-                f"Unsupported model type `{model_type}` for {cls.__name__}. "
" - f"Supported model types: {tuple(cls._hf_model_type_to_curated.keys())}" - ) - assert issubclass(module_cls, FromHFHub) - return module_cls + + supported_model_types = set() + for entrypoint, module_cls in cls._registry.get_entry_points().items(): + if not issubclass(module_cls, FromHFHub): + warnings.warn( + f"Entry point {entrypoint} cannot load from Huggingface Hub " + "since the FromHFHub mixin is not implemented" + ) + continue + + module_model_types = module_cls.hf_model_types() + if model_type in module_model_types: + return module_cls + supported_model_types.update(module_model_types) + + raise ValueError( + f"Unsupported model type `{model_type}` for {cls.__name__}. " + f"Supported model types: {', '.join(sorted(supported_model_types))}" + ) @classmethod def _instantiate_model( @@ -182,13 +187,7 @@ class AutoEncoder(AutoModel[EncoderModule[TransformerConfig]]): Encoder model loaded from the Hugging Face Model Hub. """ - _hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = { - "bert": BERTEncoder, - "albert": ALBERTEncoder, - "camembert": CamemBERTEncoder, - "roberta": RoBERTaEncoder, - "xlm-roberta": XLMREncoder, - } + _registry: Registry = registry.encoders @classmethod def from_repo( @@ -208,14 +207,7 @@ class AutoDecoder(AutoModel[DecoderModule[TransformerConfig, KeyValueCache]]): Decoder module loaded from the Hugging Face Model Hub. """ - _hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = { - "falcon": FalconDecoder, - "gpt_neox": GPTNeoXDecoder, - "llama": LlamaDecoder, - "mpt": MPTDecoder, - "RefinedWeb": FalconDecoder, - "RefinedWebModel": FalconDecoder, - } + _registry = registry.decoders @classmethod def from_repo( @@ -235,14 +227,7 @@ class AutoCausalLM(AutoModel[CausalLMModule[TransformerConfig, KeyValueCache]]): Causal LM model loaded from the Hugging Face Model Hub. """ - _hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = { - "falcon": FalconCausalLM, - "gpt_neox": GPTNeoXCausalLM, - "llama": LlamaCausalLM, - "mpt": MPTCausalLM, - "RefinedWeb": FalconCausalLM, - "RefinedWebModel": FalconCausalLM, - } + _registry: Registry = registry.causal_lms @classmethod def from_repo( diff --git a/curated_transformers/models/bert/encoder.py b/curated_transformers/models/bert/encoder.py index 363b2875..acbdb389 100644 --- a/curated_transformers/models/bert/encoder.py +++ b/curated_transformers/models/bert/encoder.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Mapping, Optional, Type, TypeVar +from typing import Any, Mapping, Optional, Tuple, Type, TypeVar import torch from torch import Tensor @@ -105,6 +105,10 @@ def __init__(self, config: BERTConfig, *, device: Optional[torch.device] = None) ] ) + @classmethod + def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: + return ("bert",) + @classmethod def state_dict_from_hf( cls: Type[Self], params: Mapping[str, Tensor] diff --git a/curated_transformers/models/camembert/encoder.py b/curated_transformers/models/camembert/encoder.py index 0c18c345..6a16919f 100644 --- a/curated_transformers/models/camembert/encoder.py +++ b/curated_transformers/models/camembert/encoder.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple import torch @@ -25,3 +25,7 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No The encoder. 
""" super().__init__(config, device=device) + + @classmethod + def hf_model_types(cls) -> Tuple[str, ...]: + return ("camembert",) diff --git a/curated_transformers/models/falcon/causal_lm.py b/curated_transformers/models/falcon/causal_lm.py index 91860d06..4c5661b2 100644 --- a/curated_transformers/models/falcon/causal_lm.py +++ b/curated_transformers/models/falcon/causal_lm.py @@ -1,4 +1,4 @@ -from typing import Any, Mapping, Optional, Set, Type, TypeVar +from typing import Any, Mapping, Optional, Set, Tuple, Type, TypeVar import torch from torch import Tensor @@ -52,6 +52,10 @@ def state_dict_from_hf( ) -> Mapping[str, Tensor]: return state_dict_from_hf(params, CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS) + @classmethod + def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: + return ("falcon", "RefinedWeb", "RefinedWebModel") + @classmethod def state_dict_to_hf( cls: Type[Self], params: Mapping[str, Tensor] diff --git a/curated_transformers/models/falcon/decoder.py b/curated_transformers/models/falcon/decoder.py index 478a70d0..6cef0fe1 100644 --- a/curated_transformers/models/falcon/decoder.py +++ b/curated_transformers/models/falcon/decoder.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Mapping, Optional, Type, TypeVar +from typing import Any, Mapping, Optional, Tuple, Type, TypeVar import torch from torch import Tensor @@ -86,6 +86,10 @@ def __init__( device=device, ) + @classmethod + def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: + return ("falcon", "RefinedWeb", "RefinedWebModel") + @classmethod def state_dict_from_hf( cls: Type[Self], params: Mapping[str, Tensor] diff --git a/curated_transformers/models/gpt_neox/causal_lm.py b/curated_transformers/models/gpt_neox/causal_lm.py index eae67a2d..302f304a 100644 --- a/curated_transformers/models/gpt_neox/causal_lm.py +++ b/curated_transformers/models/gpt_neox/causal_lm.py @@ -1,4 +1,4 @@ -from typing import Any, Mapping, Optional, Set, Type, TypeVar +from typing import Any, Mapping, Optional, Set, Tuple, Type, TypeVar import torch from torch import Tensor @@ -46,6 +46,10 @@ def __init__( device=device, ) + @classmethod + def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: + return ("gpt_neox",) + @classmethod def state_dict_from_hf( cls: Type[Self], params: Mapping[str, Tensor] diff --git a/curated_transformers/models/gpt_neox/decoder.py b/curated_transformers/models/gpt_neox/decoder.py index c7eb634a..7c41ea5c 100644 --- a/curated_transformers/models/gpt_neox/decoder.py +++ b/curated_transformers/models/gpt_neox/decoder.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Mapping, Optional, Type, TypeVar +from typing import Any, Mapping, Optional, Tuple, Type, TypeVar import torch from torch import Tensor @@ -114,6 +114,10 @@ def __init__( hidden_width, config.layer.layer_norm_eps, device=device ) + @classmethod + def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: + return ("gpt_neox",) + @classmethod def state_dict_from_hf( cls: Type[Self], params: Mapping[str, Tensor] diff --git a/curated_transformers/models/hf_hub/mixin.py b/curated_transformers/models/hf_hub/mixin.py index b0807c8c..2ee59ac2 100644 --- a/curated_transformers/models/hf_hub/mixin.py +++ b/curated_transformers/models/hf_hub/mixin.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Mapping, Optional, Type, TypeVar +from typing import Any, Mapping, Optional, Tuple, Type, TypeVar import torch from fsspec import AbstractFileSystem @@ -168,6 +168,17 @@ def from_hf_hub( 
             quantization_config=quantization_config,
         )
 
+    @classmethod
+    @abstractmethod
+    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
+        """
+        Get the Hugging Face model types supported by this model.
+
+        :returns:
+            The supported model types.
+        """
+        ...
+
     @abstractmethod
     def to(
         self,
diff --git a/curated_transformers/models/llama/causal_lm.py b/curated_transformers/models/llama/causal_lm.py
index aa9f5320..0487aed1 100644
--- a/curated_transformers/models/llama/causal_lm.py
+++ b/curated_transformers/models/llama/causal_lm.py
@@ -1,4 +1,4 @@
-from typing import Any, Mapping, Optional, Set, Type, TypeVar
+from typing import Any, Mapping, Optional, Set, Tuple, Type, TypeVar
 
 import torch
 from torch import Tensor
@@ -47,6 +47,10 @@ def __init__(
             device=device,
         )
 
+    @classmethod
+    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
+        return ("llama",)
+
     @classmethod
     def state_dict_from_hf(
         cls: Type[Self], params: Mapping[str, Tensor]
diff --git a/curated_transformers/models/llama/decoder.py b/curated_transformers/models/llama/decoder.py
index 1e4b7e69..e3e51aa6 100644
--- a/curated_transformers/models/llama/decoder.py
+++ b/curated_transformers/models/llama/decoder.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, Mapping, Optional, Type, TypeVar
+from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
 
 import torch
 from torch import Tensor
@@ -120,6 +120,10 @@ def __init__(
             hidden_width, eps=config.layer.layer_norm_eps, device=device
         )
 
+    @classmethod
+    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
+        return ("llama",)
+
     @classmethod
     def state_dict_from_hf(
         cls: Type[Self], params: Mapping[str, Tensor]
diff --git a/curated_transformers/models/mpt/causal_lm.py b/curated_transformers/models/mpt/causal_lm.py
index dfc9c72a..de2c3916 100644
--- a/curated_transformers/models/mpt/causal_lm.py
+++ b/curated_transformers/models/mpt/causal_lm.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Mapping, Optional, Set, Type, TypeVar
+from typing import Any, List, Mapping, Optional, Set, Tuple, Type, TypeVar
 
 import torch
 import torch.nn.functional as F
@@ -84,6 +84,10 @@ def forward(
             logits=logits,
         )
 
+    @classmethod
+    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
+        return ("mpt",)
+
     @classmethod
     def state_dict_from_hf(
         cls: Type[Self], params: Mapping[str, Tensor]
diff --git a/curated_transformers/models/mpt/decoder.py b/curated_transformers/models/mpt/decoder.py
index f6acc05d..f72b9b31 100644
--- a/curated_transformers/models/mpt/decoder.py
+++ b/curated_transformers/models/mpt/decoder.py
@@ -1,4 +1,4 @@
-from typing import Any, Mapping, Optional, Type, TypeVar
+from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
 
 import torch
 from torch import Tensor
@@ -120,6 +120,10 @@ def layer_norm():
 
         self.output_layer_norm = layer_norm()
 
+    @classmethod
+    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
+        return ("mpt",)
+
     @classmethod
     def state_dict_from_hf(
         cls: Type[Self], params: Mapping[str, Tensor]
diff --git a/curated_transformers/models/roberta/encoder.py b/curated_transformers/models/roberta/encoder.py
index b9cfc3f3..c642a6fa 100644
--- a/curated_transformers/models/roberta/encoder.py
+++ b/curated_transformers/models/roberta/encoder.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, Mapping, Optional, Type, TypeVar
+from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
 
 import torch
 from torch import Tensor
@@ -105,6 +105,10 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No
             ]
         )
 
+    @classmethod
+    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
+        return ("roberta",)
+
     @classmethod
     def state_dict_from_hf(
         cls: Type[Self], params: Mapping[str, Tensor]
diff --git a/curated_transformers/models/xlm_roberta/encoder.py b/curated_transformers/models/xlm_roberta/encoder.py
index 300a7a22..ee398958 100644
--- a/curated_transformers/models/xlm_roberta/encoder.py
+++ b/curated_transformers/models/xlm_roberta/encoder.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 
@@ -25,3 +25,7 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No
             The encoder.
         """
         super().__init__(config, device=device)
+
+    @classmethod
+    def hf_model_types(cls) -> Tuple[str, ...]:
+        return ("xlm-roberta",)
diff --git a/docs/source/api.rst b/docs/source/api.rst
index 03cf662b..4bece367 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -9,6 +9,7 @@ API
     decoders
     causal-lm
     generation
+    registries
     repositories
     tokenizers
     quantization
diff --git a/docs/source/registries.rst b/docs/source/registries.rst
new file mode 100644
index 00000000..49c6cd89
--- /dev/null
+++ b/docs/source/registries.rst
@@ -0,0 +1,40 @@
+Registries
+==========
+
+All models in Curated Transformers are added to a registry. The auto classes
+query these registries to discover the available models. This mechanism
+allows third-party models to hook into the auto classes, so construction
+methods such as ``AutoModel.from_hf_hub`` also work with third-party models.
+
+Third-party packages can register models in the ``options.entry_points``
+section of ``setup.cfg``. For example, if the ``models`` module of the
+``extra-transformers`` package contains the ``FooCausalLM``, ``BarDecoder``,
+and ``BazEncoder`` classes, they can be registered in ``setup.cfg`` as
+follows:
+
+.. code-block:: ini
+
+    [options.entry_points]
+    curated_transformers_causal_lms =
+        extra-transformers.FooCausalLM = extra_transformers.models:FooCausalLM
+
+    curated_transformers_decoders =
+        extra-transformers.BarDecoder = extra_transformers.models:BarDecoder
+
+    curated_transformers_encoders =
+        extra-transformers.BazEncoder = extra_transformers.models:BazEncoder
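+
+Once such a package is installed, the auto classes discover its models
+through these entry points. The following is a minimal sketch of what this
+enables. It assumes that the hypothetical ``FooCausalLM`` implements the
+``FromHFHub`` mixin, that its ``hf_model_types`` method returns ``("foo",)``,
+and that the (equally hypothetical) repository ``example-org/foo-lm`` declares
+the model type ``foo``:
+
+.. code-block:: python
+
+    from curated_transformers.models import AutoCausalLM
+
+    # AutoCausalLM reads the model type from the repository's configuration
+    # and resolves it to FooCausalLM through the registered entry points.
+    causal_lm = AutoCausalLM.from_hf_hub(name="example-org/foo-lm")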
diff --git a/requirements.txt b/requirements.txt
index 8d0472b6..dd5b846a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+catalogue>=2.0.4,<2.1.0
 curated-tokenizers>=0.9.1,<1.0.0
 huggingface-hub>=0.14
 tokenizers>=0.13.3
diff --git a/setup.cfg b/setup.cfg
index 70b92871..29a6d811 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -14,6 +14,7 @@ zip_safe = true
 include_package_data = true
 python_requires = >=3.8
 install_requires =
+    catalogue>=2.0.4,<2.1.0
     curated-tokenizers>=0.9.1,<1.0.0
     huggingface-hub>=0.14
     tokenizers>=0.13.3
@@ -24,9 +25,29 @@
 quantization =
     bitsandbytes>=0.40
     # bitsandbytes has a dependency on scipy but doesn't
    # list it as one for pip installs. So, we'll pull that
-    # in too (until its rectified upstream).
+    # in too (until it's rectified upstream).
     scipy>=1.11
+
+[options.entry_points]
+curated_transformers_causal_lms =
+    curated-transformers.LlamaCausalLM = curated_transformers.models:LlamaCausalLM
+    curated-transformers.FalconCausalLM = curated_transformers.models:FalconCausalLM
+    curated-transformers.GPTNeoXCausalLM = curated_transformers.models:GPTNeoXCausalLM
+    curated-transformers.MPTCausalLM = curated_transformers.models:MPTCausalLM
+
+curated_transformers_decoders =
+    curated-transformers.LlamaDecoder = curated_transformers.models:LlamaDecoder
+    curated-transformers.FalconDecoder = curated_transformers.models:FalconDecoder
+    curated-transformers.GPTNeoXDecoder = curated_transformers.models:GPTNeoXDecoder
+    curated-transformers.MPTDecoder = curated_transformers.models:MPTDecoder
+
+curated_transformers_encoders =
+    curated-transformers.ALBERTEncoder = curated_transformers.models:ALBERTEncoder
+    curated-transformers.BERTEncoder = curated_transformers.models:BERTEncoder
+    curated-transformers.CamemBERTEncoder = curated_transformers.models:CamemBERTEncoder
+    curated-transformers.RoBERTaEncoder = curated_transformers.models:RoBERTaEncoder
+    curated-transformers.XLMREncoder = curated_transformers.models:XLMREncoder
 
 [bdist_wheel]
 universal = true