Register models using catalogue
So far we have hardcoded the available encoders/decoders/causal LMs in
the auto classes. This has the downside that the auto classes only work
with models that are provided by Curated Transformers.

This change adds registries for encoders/decoders/causal LMs. The auto
classes query the relevant registry and check which registered model
supports the downloaded model (through the `hf_model_types` method of
the `FromHFHub` mixin). This makes it possible to register external
models with Curated Transformers, so that they can also be used with the
auto classes.

Adding registries for tokenizers and generators is deferred to future
PRs.
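
As an illustration of the new mechanism, a third-party package could expose its model through a Python entry point so that the auto classes discover it. The sketch below is a minimal example, not part of this commit: it assumes catalogue derives the entry-point group name by joining the registry namespace with an underscore (`curated_transformers_encoders`), and `my_encoders`/`MyEncoder` are hypothetical names for a class that implements the `FromHFHub` mixin, including `hf_model_types`.

```python
# setup.py of a hypothetical third-party package. A minimal sketch, assuming
# the entry-point group follows catalogue's namespace convention
# ("curated_transformers" + "encoders" -> "curated_transformers_encoders").
from setuptools import setup

setup(
    name="my-encoders",
    packages=["my_encoders"],
    entry_points={
        "curated_transformers_encoders": [
            # MyEncoder is a hypothetical encoder class that implements the
            # FromHFHub mixin; its hf_model_types() returns the Hugging Face
            # model types it supports, e.g. ("my-model",).
            "my_encoder = my_encoders.encoder:MyEncoder",
        ],
    },
)
```

Once such a package is installed, `AutoEncoder.from_hf_hub(...)` would resolve to `MyEncoder` whenever the `model_type` of the downloaded model matches one of the values returned by `MyEncoder.hf_model_types()`.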
danieldk committed Oct 5, 2023
1 parent 93cf07f commit fcbdfef
Showing 20 changed files with 172 additions and 58 deletions.
19 changes: 19 additions & 0 deletions curated_transformers/__init__.py
@@ -0,0 +1,19 @@
import catalogue
from catalogue import Registry


class registry(object):
"""
Registries for models. These registries are used by the auto classes to
discover the available models.
"""

causal_lms: Registry = catalogue.create(
"curated_transformers", "causal_lms", entry_points=True
)
decoders: Registry = catalogue.create(
"curated_transformers", "decoders", entry_points=True
)
encoders: Registry = catalogue.create(
"curated_transformers", "encoders", entry_points=True
)
6 changes: 5 additions & 1 deletion curated_transformers/models/albert/encoder.py
@@ -1,4 +1,4 @@
from typing import Any, Mapping, Optional, Type, TypeVar
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar

import torch
from torch import Tensor
@@ -99,6 +99,10 @@ def forward(

return ModelOutput(all_outputs=[embeddings, *layer_outputs])

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("albert",)

@classmethod
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
71 changes: 28 additions & 43 deletions curated_transformers/models/auto_model.py
@@ -1,27 +1,21 @@
import warnings
from abc import ABC, abstractmethod
from typing import Dict, Generic, Optional, Type, TypeVar
from typing import Generic, Optional, Type, TypeVar

import torch
from catalogue import Registry
from fsspec import AbstractFileSystem

from curated_transformers import registry

from ..layers.cache import KeyValueCache
from ..quantization.bnb.config import BitsAndBytesConfig
from ..repository.fsspec import FsspecArgs, FsspecRepository
from ..repository.hf_hub import HfHubRepository
from ..repository.repository import ModelRepository, Repository
from .albert import ALBERTEncoder
from .bert import BERTEncoder
from .camembert import CamemBERTEncoder
from .config import TransformerConfig
from .falcon import FalconCausalLM, FalconDecoder
from .gpt_neox import GPTNeoXCausalLM, GPTNeoXDecoder
from .hf_hub import FromHFHub
from .llama import LlamaCausalLM, LlamaDecoder
from .module import CausalLMModule, DecoderModule, EncoderModule
from .mpt.causal_lm import MPTCausalLM
from .mpt.decoder import MPTDecoder
from .roberta import RoBERTaEncoder
from .xlm_roberta import XLMREncoder

ModelT = TypeVar("ModelT")

@@ -32,22 +26,33 @@ class AutoModel(ABC, Generic[ModelT]):
Face Model Hub.
"""

_hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = {}
_registry: Registry

@classmethod
def _resolve_model_cls(
cls,
repo: ModelRepository,
) -> Type[FromHFHub]:
model_type = repo.model_type()
module_cls = cls._hf_model_type_to_curated.get(model_type)
if module_cls is None:
raise ValueError(
f"Unsupported model type `{model_type}` for {cls.__name__}. "
f"Supported model types: {tuple(cls._hf_model_type_to_curated.keys())}"
)
assert issubclass(module_cls, FromHFHub)
return module_cls

supported_model_types = set()
for entrypoint, module_cls in cls._registry.get_entry_points().items():
if not issubclass(module_cls, FromHFHub):
warnings.warn(
f"Entry point {entrypoint} cannot load from Huggingface Hub "
"since the FromHFHub mixin is not implemented"
)
continue

module_model_types = module_cls.hf_model_types()
if model_type in module_model_types:
return module_cls
supported_model_types.update(module_model_types)

raise ValueError(
f"Unsupported model type `{model_type}` for {cls.__name__}. "
f"Supported model types: {', '.join(sorted(supported_model_types))}"
)

@classmethod
def _instantiate_model(
@@ -182,13 +187,7 @@ class AutoEncoder(AutoModel[EncoderModule[TransformerConfig]]):
Encoder model loaded from the Hugging Face Model Hub.
"""

_hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = {
"bert": BERTEncoder,
"albert": ALBERTEncoder,
"camembert": CamemBERTEncoder,
"roberta": RoBERTaEncoder,
"xlm-roberta": XLMREncoder,
}
_registry: Registry = registry.encoders

@classmethod
def from_repo(
@@ -208,14 +207,7 @@ class AutoDecoder(AutoModel[DecoderModule[TransformerConfig, KeyValueCache]]):
Decoder module loaded from the Hugging Face Model Hub.
"""

_hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = {
"falcon": FalconDecoder,
"gpt_neox": GPTNeoXDecoder,
"llama": LlamaDecoder,
"mpt": MPTDecoder,
"RefinedWeb": FalconDecoder,
"RefinedWebModel": FalconDecoder,
}
_registry = registry.decoders

@classmethod
def from_repo(
@@ -235,14 +227,7 @@ class AutoCausalLM(AutoModel[CausalLMModule[TransformerConfig, KeyValueCache]]):
Causal LM model loaded from the Hugging Face Model Hub.
"""

_hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = {
"falcon": FalconCausalLM,
"gpt_neox": GPTNeoXCausalLM,
"llama": LlamaCausalLM,
"mpt": MPTCausalLM,
"RefinedWeb": FalconCausalLM,
"RefinedWebModel": FalconCausalLM,
}
_registry: Registry = registry.causal_lms

@classmethod
def from_repo(
6 changes: 5 additions & 1 deletion curated_transformers/models/bert/encoder.py
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Mapping, Optional, Type, TypeVar
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar

import torch
from torch import Tensor
@@ -105,6 +105,10 @@ def __init__(self, config: BERTConfig, *, device: Optional[torch.device] = None)
]
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("bert",)

@classmethod
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
6 changes: 5 additions & 1 deletion curated_transformers/models/camembert/encoder.py
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Tuple

import torch

@@ -25,3 +25,7 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No
The encoder.
"""
super().__init__(config, device=device)

@classmethod
def hf_model_types(cls) -> Tuple[str, ...]:
return ("camembert",)
6 changes: 5 additions & 1 deletion curated_transformers/models/falcon/causal_lm.py
@@ -1,4 +1,4 @@
from typing import Any, Mapping, Optional, Set, Type, TypeVar
from typing import Any, Mapping, Optional, Set, Tuple, Type, TypeVar

import torch
from torch import Tensor
@@ -52,6 +52,10 @@ def state_dict_from_hf(
) -> Mapping[str, Tensor]:
return state_dict_from_hf(params, CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("falcon", "RefinedWeb", "RefinedWebModel")

@classmethod
def state_dict_to_hf(
cls: Type[Self], params: Mapping[str, Tensor]
6 changes: 5 additions & 1 deletion curated_transformers/models/falcon/decoder.py
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Mapping, Optional, Type, TypeVar
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar

import torch
from torch import Tensor
@@ -86,6 +86,10 @@ def __init__(
device=device,
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("falcon", "RefinedWeb", "RefinedWebModel")

@classmethod
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
6 changes: 5 additions & 1 deletion curated_transformers/models/gpt_neox/causal_lm.py
@@ -1,4 +1,4 @@
from typing import Any, Mapping, Optional, Set, Type, TypeVar
from typing import Any, Mapping, Optional, Set, Tuple, Type, TypeVar

import torch
from torch import Tensor
@@ -46,6 +46,10 @@ def __init__(
device=device,
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("gpt_neox",)

@classmethod
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
6 changes: 5 additions & 1 deletion curated_transformers/models/gpt_neox/decoder.py
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Mapping, Optional, Type, TypeVar
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar

import torch
from torch import Tensor
@@ -114,6 +114,10 @@ def __init__(
hidden_width, config.layer.layer_norm_eps, device=device
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("gpt_neox",)

@classmethod
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
13 changes: 12 additions & 1 deletion curated_transformers/models/hf_hub/mixin.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Mapping, Optional, Type, TypeVar
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar

import torch
from fsspec import AbstractFileSystem
@@ -168,6 +168,17 @@ def from_hf_hub(
quantization_config=quantization_config,
)

@classmethod
@abstractmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
"""
Get the Hugging Face model types supported by this model.
:returns:
The supported model types.
"""
...

@abstractmethod
def to(
self,
6 changes: 5 additions & 1 deletion curated_transformers/models/llama/causal_lm.py
@@ -1,4 +1,4 @@
from typing import Any, Mapping, Optional, Set, Type, TypeVar
from typing import Any, Mapping, Optional, Set, Tuple, Type, TypeVar

import torch
from torch import Tensor
@@ -47,6 +47,10 @@ def __init__(
device=device,
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("llama",)

@classmethod
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
6 changes: 5 additions & 1 deletion curated_transformers/models/llama/decoder.py
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Mapping, Optional, Type, TypeVar
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar

import torch
from torch import Tensor
@@ -120,6 +120,10 @@ def __init__(
hidden_width, eps=config.layer.layer_norm_eps, device=device
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("llama",)

@classmethod
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
6 changes: 5 additions & 1 deletion curated_transformers/models/mpt/causal_lm.py
@@ -1,4 +1,4 @@
from typing import Any, List, Mapping, Optional, Set, Type, TypeVar
from typing import Any, List, Mapping, Optional, Set, Tuple, Type, TypeVar

import torch
import torch.nn.functional as F
@@ -84,6 +84,10 @@ def forward(
logits=logits,
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("mpt",)

@classmethod
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
6 changes: 5 additions & 1 deletion curated_transformers/models/mpt/decoder.py
@@ -1,4 +1,4 @@
from typing import Any, Mapping, Optional, Type, TypeVar
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar

import torch
from torch import Tensor
@@ -120,6 +120,10 @@ def layer_norm():

self.output_layer_norm = layer_norm()

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("mpt",)

@classmethod
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
6 changes: 5 additions & 1 deletion curated_transformers/models/roberta/encoder.py
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Mapping, Optional, Type, TypeVar
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar

import torch
from torch import Tensor
@@ -105,6 +105,10 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No
]
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("roberta",)

@classmethod
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
6 changes: 5 additions & 1 deletion curated_transformers/models/xlm_roberta/encoder.py
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Tuple

import torch

@@ -25,3 +25,7 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No
The encoder.
"""
super().__init__(config, device=device)

@classmethod
def hf_model_types(cls) -> Tuple[str, ...]:
return ("xlm-roberta",)
1 change: 1 addition & 0 deletions docs/source/api.rst
@@ -9,6 +9,7 @@ API
decoders
causal-lm
generation
registries
repositories
tokenizers
quantization