diff --git a/curated_transformers/models/albert/encoder.py b/curated_transformers/models/albert/encoder.py
index 2c4ab449..ccb0190d 100644
--- a/curated_transformers/models/albert/encoder.py
+++ b/curated_transformers/models/albert/encoder.py
@@ -1,4 +1,4 @@
-from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
+from typing import Any, Dict, Mapping, Optional, Tuple, Type, TypeVar

 import torch
 from torch import Tensor
@@ -100,8 +100,8 @@ def forward(
         return ModelOutput(all_outputs=[embeddings, *layer_outputs])

     @classmethod
-    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
-        return ("albert",)
+    def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
+        return config.get("model_type") == "albert"

     @classmethod
     def state_dict_from_hf(
diff --git a/curated_transformers/models/auto_model.py b/curated_transformers/models/auto_model.py
index 2c2732cd..172c5b1d 100644
--- a/curated_transformers/models/auto_model.py
+++ b/curated_transformers/models/auto_model.py
@@ -15,7 +15,7 @@
 from ..repository.repository import ModelRepository, Repository
 from .config import TransformerConfig
 from .hf_hub import FromHFHub
-from .module import CausalLMModule, DecoderModule, EncoderModule
+from .module import CausalLMModule, DecoderModule, EncoderModule, TransformerModule

 ModelT = TypeVar("ModelT")

@@ -26,6 +26,7 @@ class AutoModel(ABC, Generic[ModelT]):
     Face Model Hub.
     """

+    _base_cls: Type[TransformerModule]
     _registry: Registry

     @classmethod
@@ -33,9 +34,8 @@ def _resolve_model_cls(
         cls,
         repo: ModelRepository,
     ) -> Type[FromHFHub]:
-        model_type = repo.model_type()
+        config = repo.model_config()

-        supported_model_types = set()
         for entrypoint, module_cls in cls._registry.get_entry_points().items():
             if not issubclass(module_cls, FromHFHub):
                 warnings.warn(
@@ -44,14 +44,24 @@
                 )
                 continue

-            module_model_types = module_cls.hf_model_types()
-            if model_type in module_model_types:
+            if not issubclass(module_cls, cls._base_cls):
+                warnings.warn(
+                    f"Entry point `{entrypoint}` cannot be used by `{cls.__name__}` "
+                    f"since it does not have `{cls._base_cls.__name__}` "
+                    "as its base class"
+                )
+                continue
+
+            if module_cls.is_supported(config):
                 return module_cls
-            supported_model_types.update(module_model_types)
+
+        entrypoints = {
+            entrypoint for entrypoint in cls._registry.get_entry_points().keys()
+        }

         raise ValueError(
-            f"Unsupported model type `{model_type}` for {cls.__name__}. "
-            f"Supported model types: {', '.join(sorted(supported_model_types))}"
+            f"Unsupported model type for `{cls.__name__}`. "
+            f"Registered models: {', '.join(sorted(entrypoints))}"
         )

     @classmethod
@@ -187,6 +197,7 @@ class AutoEncoder(AutoModel[EncoderModule[TransformerConfig]]):
     Encoder model loaded from the Hugging Face Model Hub.
     """

+    _base_cls = EncoderModule
     _registry: Registry = registry.encoders

     @classmethod
@@ -207,6 +218,7 @@ class AutoDecoder(AutoModel[DecoderModule[TransformerConfig, KeyValueCache]]):
     Decoder module loaded from the Hugging Face Model Hub.
     """

+    _base_cls = DecoderModule
     _registry = registry.decoders

     @classmethod
@@ -227,6 +239,7 @@ class AutoCausalLM(AutoModel[CausalLMModule[TransformerConfig, KeyValueCache]]):
     Causal LM model loaded from the Hugging Face Model Hub.
""" + _base_cls = CausalLMModule _registry: Registry = registry.causal_lms @classmethod diff --git a/curated_transformers/models/bert/encoder.py b/curated_transformers/models/bert/encoder.py index 73d81c82..e62be1ee 100644 --- a/curated_transformers/models/bert/encoder.py +++ b/curated_transformers/models/bert/encoder.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Mapping, Optional, Tuple, Type, TypeVar +from typing import Any, Dict, Mapping, Optional, Tuple, Type, TypeVar import torch from torch import Tensor @@ -114,8 +114,8 @@ def __init__(self, config: BERTConfig, *, device: Optional[torch.device] = None) ) @classmethod - def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: - return ("bert",) + def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool: + return config.get("model_type") == "bert" @classmethod def state_dict_from_hf( diff --git a/curated_transformers/models/camembert/encoder.py b/curated_transformers/models/camembert/encoder.py index 6a16919f..7d5036ba 100644 --- a/curated_transformers/models/camembert/encoder.py +++ b/curated_transformers/models/camembert/encoder.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Any, Dict, Optional, Tuple import torch @@ -27,5 +27,5 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No super().__init__(config, device=device) @classmethod - def hf_model_types(cls) -> Tuple[str, ...]: - return ("camembert",) + def is_supported(cls, config: Dict[str, Any]) -> bool: + return config.get("model_type") == "camembert" diff --git a/curated_transformers/models/falcon/causal_lm.py b/curated_transformers/models/falcon/causal_lm.py index 4c5661b2..d1c9e89d 100644 --- a/curated_transformers/models/falcon/causal_lm.py +++ b/curated_transformers/models/falcon/causal_lm.py @@ -1,4 +1,4 @@ -from typing import Any, Mapping, Optional, Set, Tuple, Type, TypeVar +from typing import Any, Dict, Mapping, Optional, Set, Tuple, Type, TypeVar import torch from torch import Tensor @@ -53,8 +53,8 @@ def state_dict_from_hf( return state_dict_from_hf(params, CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS) @classmethod - def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: - return ("falcon", "RefinedWeb", "RefinedWebModel") + def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool: + return config.get("model_type") in ("falcon", "RefinedWeb", "RefinedWebModel") @classmethod def state_dict_to_hf( diff --git a/curated_transformers/models/falcon/decoder.py b/curated_transformers/models/falcon/decoder.py index 659496cd..9c34e2cb 100644 --- a/curated_transformers/models/falcon/decoder.py +++ b/curated_transformers/models/falcon/decoder.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Mapping, Optional, Tuple, Type, TypeVar +from typing import Any, Dict, Mapping, Optional, Tuple, Type, TypeVar import torch from torch import Tensor @@ -88,8 +88,8 @@ def __init__( ) @classmethod - def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: - return ("falcon", "RefinedWeb", "RefinedWebModel") + def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool: + return config.get("model_type") in ("falcon", "RefinedWeb", "RefinedWebModel") @classmethod def state_dict_from_hf( diff --git a/curated_transformers/models/gpt_neox/causal_lm.py b/curated_transformers/models/gpt_neox/causal_lm.py index 302f304a..45a6991a 100644 --- a/curated_transformers/models/gpt_neox/causal_lm.py +++ b/curated_transformers/models/gpt_neox/causal_lm.py @@ -1,4 +1,4 @@ -from typing import 
Any, Mapping, Optional, Set, Tuple, Type, TypeVar +from typing import Any, Dict, Mapping, Optional, Set, Tuple, Type, TypeVar import torch from torch import Tensor @@ -47,8 +47,8 @@ def __init__( ) @classmethod - def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: - return ("gpt_neox",) + def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool: + return config.get("model_type") == "gpt_neox" @classmethod def state_dict_from_hf( diff --git a/curated_transformers/models/gpt_neox/decoder.py b/curated_transformers/models/gpt_neox/decoder.py index 3f6cd19f..d6b0f385 100644 --- a/curated_transformers/models/gpt_neox/decoder.py +++ b/curated_transformers/models/gpt_neox/decoder.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Mapping, Optional, Tuple, Type, TypeVar +from typing import Any, Dict, Mapping, Optional, Tuple, Type, TypeVar import torch from torch import Tensor @@ -123,8 +123,8 @@ def __init__( ) @classmethod - def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: - return ("gpt_neox",) + def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool: + return config.get("model_type") == "gpt_neox" @classmethod def state_dict_from_hf( diff --git a/curated_transformers/models/hf_hub/mixin.py b/curated_transformers/models/hf_hub/mixin.py index 393d4254..1a3ff673 100644 --- a/curated_transformers/models/hf_hub/mixin.py +++ b/curated_transformers/models/hf_hub/mixin.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Mapping, Optional, Tuple, Type, TypeVar +from typing import Any, Dict, Mapping, Optional, Tuple, Type, TypeVar import torch from fsspec import AbstractFileSystem @@ -170,14 +170,17 @@ def from_hf_hub( @classmethod @abstractmethod - def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: + def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool: """ - Get the Hugging Face model types supported by this model. + Check if the model with the given configuration is supported by this + class. + :param config: + Hugging Face model configuration. :returns: - The supported model types. + Whether the model is supported by this class. """ - ... 
+        raise NotImplementedError

     @abstractmethod
     def to(
diff --git a/curated_transformers/models/llama/causal_lm.py b/curated_transformers/models/llama/causal_lm.py
index 0487aed1..51b77a4b 100644
--- a/curated_transformers/models/llama/causal_lm.py
+++ b/curated_transformers/models/llama/causal_lm.py
@@ -1,4 +1,4 @@
-from typing import Any, Mapping, Optional, Set, Tuple, Type, TypeVar
+from typing import Any, Dict, Mapping, Optional, Set, Tuple, Type, TypeVar

 import torch
 from torch import Tensor
@@ -48,8 +48,8 @@ def __init__(
         )

     @classmethod
-    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
-        return ("llama",)
+    def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
+        return config.get("model_type") == "llama"

     @classmethod
     def state_dict_from_hf(
diff --git a/curated_transformers/models/llama/decoder.py b/curated_transformers/models/llama/decoder.py
index d660d29b..64863bfd 100644
--- a/curated_transformers/models/llama/decoder.py
+++ b/curated_transformers/models/llama/decoder.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
+from typing import Any, Dict, Mapping, Optional, Tuple, Type, TypeVar

 import torch
 from torch import Tensor
@@ -129,8 +129,8 @@ def __init__(
         )

     @classmethod
-    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
-        return ("llama",)
+    def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
+        return config.get("model_type") == "llama"

     @classmethod
     def state_dict_from_hf(
diff --git a/curated_transformers/models/mpt/causal_lm.py b/curated_transformers/models/mpt/causal_lm.py
index de2c3916..19e26dbd 100644
--- a/curated_transformers/models/mpt/causal_lm.py
+++ b/curated_transformers/models/mpt/causal_lm.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Mapping, Optional, Set, Tuple, Type, TypeVar
+from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Type, TypeVar

 import torch
 import torch.nn.functional as F
@@ -85,8 +85,8 @@ def forward(
         )

     @classmethod
-    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
-        return ("mpt",)
+    def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
+        return config.get("model_type") == "mpt"

     @classmethod
     def state_dict_from_hf(
diff --git a/curated_transformers/models/mpt/decoder.py b/curated_transformers/models/mpt/decoder.py
index e05de9c6..0274f120 100644
--- a/curated_transformers/models/mpt/decoder.py
+++ b/curated_transformers/models/mpt/decoder.py
@@ -1,4 +1,4 @@
-from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
+from typing import Any, Dict, Mapping, Optional, Tuple, Type, TypeVar

 import torch
 from torch import Tensor
@@ -124,8 +124,8 @@ def layer_norm():
         self.output_layer_norm = layer_norm()

     @classmethod
-    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
-        return ("mpt",)
+    def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
+        return config.get("model_type") == "mpt"

     @classmethod
     def state_dict_from_hf(
diff --git a/curated_transformers/models/roberta/encoder.py b/curated_transformers/models/roberta/encoder.py
index d98e980e..b00a16b1 100644
--- a/curated_transformers/models/roberta/encoder.py
+++ b/curated_transformers/models/roberta/encoder.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
+from typing import Any, Dict, Mapping, Optional, Tuple, Type, TypeVar

 import torch
 from torch import Tensor
@@ -114,8 +114,8 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No
         )

     @classmethod
-    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
-        return ("roberta",)
+    def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
+        return config.get("model_type") == "roberta"

     @classmethod
     def state_dict_from_hf(
diff --git a/curated_transformers/models/xlm_roberta/encoder.py b/curated_transformers/models/xlm_roberta/encoder.py
index ee398958..ecf4a407 100644
--- a/curated_transformers/models/xlm_roberta/encoder.py
+++ b/curated_transformers/models/xlm_roberta/encoder.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple
+from typing import Any, Dict, Optional, Tuple

 import torch

@@ -27,5 +27,5 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No
         super().__init__(config, device=device)

     @classmethod
-    def hf_model_types(cls) -> Tuple[str, ...]:
-        return ("xlm-roberta",)
+    def is_supported(cls, config: Dict[str, Any]) -> bool:
+        return config.get("model_type") == "xlm-roberta"
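
Note for reviewers (not part of the patch): the switch from `hf_model_types()` to `is_supported()` means a registered class now receives the full Hugging Face configuration dictionary and returns a boolean, so a support check can look at more than the `model_type` string and match aliases the way the Falcon classes above do. A minimal sketch of the new contract, using a hypothetical `ToyModel` class and made-up config values:

# A minimal sketch of the `is_supported` contract introduced above.
# `ToyModel` is hypothetical and exists only for illustration.
from typing import Any, Dict


class ToyModel:
    @classmethod
    def is_supported(cls, config: Dict[str, Any]) -> bool:
        # The predicate sees the whole config dictionary, not just a
        # model-type string, so several aliases can be accepted at once.
        return config.get("model_type") in ("falcon", "RefinedWeb", "RefinedWebModel")


assert ToyModel.is_supported({"model_type": "RefinedWeb"})
assert not ToyModel.is_supported({"model_type": "bert"})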