Skip to content

Commit

Permalink
AutoModel: let models check if the configuration is supported (#352)
Browse files Browse the repository at this point in the history
* AutoModel: let models check if the configuration is supported

This will allow us to split up more complex models like Falcon into
multiple classes and corresponding entry points.

* Doc fix

Co-authored-by: Madeesh Kannan <[email protected]>

---------

Co-authored-by: Madeesh Kannan <[email protected]>
  • Loading branch information
danieldk and shadeMe committed Oct 5, 2023
1 parent 3c25097 commit dfe6d96
Show file tree
Hide file tree
Showing 15 changed files with 68 additions and 52 deletions.
6 changes: 3 additions & 3 deletions curated_transformers/models/albert/encoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
from typing import Any, Dict, Mapping, Optional, Tuple, Type, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -100,8 +100,8 @@ def forward(
return ModelOutput(all_outputs=[embeddings, *layer_outputs])

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("albert",)
def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
return config.get("model_type") == "albert"

@classmethod
def state_dict_from_hf(
Expand Down
29 changes: 21 additions & 8 deletions curated_transformers/models/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ..repository.repository import ModelRepository, Repository
from .config import TransformerConfig
from .hf_hub import FromHFHub
from .module import CausalLMModule, DecoderModule, EncoderModule
from .module import CausalLMModule, DecoderModule, EncoderModule, TransformerModule

ModelT = TypeVar("ModelT")

Expand All @@ -26,16 +26,16 @@ class AutoModel(ABC, Generic[ModelT]):
Face Model Hub.
"""

_base_cls: Type[TransformerModule]
_registry: Registry

@classmethod
def _resolve_model_cls(
cls,
repo: ModelRepository,
) -> Type[FromHFHub]:
model_type = repo.model_type()
config = repo.model_config()

supported_model_types = set()
for entrypoint, module_cls in cls._registry.get_entry_points().items():
if not issubclass(module_cls, FromHFHub):
warnings.warn(
Expand All @@ -44,14 +44,24 @@ def _resolve_model_cls(
)
continue

module_model_types = module_cls.hf_model_types()
if model_type in module_model_types:
if not issubclass(module_cls, cls._base_cls):
warnings.warn(
f"Entry point `{entrypoint}` cannot be used by `{cls.__name__}` "
f"since it does not have `{cls._base_cls.__name__}` "
"as its base class"
)
continue

if module_cls.is_supported(config):
return module_cls
supported_model_types.update(module_model_types)

entrypoints = {
entrypoint for entrypoint in cls._registry.get_entry_points().keys()
}

raise ValueError(
f"Unsupported model type `{model_type}` for {cls.__name__}. "
f"Supported model types: {', '.join(sorted(supported_model_types))}"
f"Unsupported model type for `{cls.__name__}`. "
f"Registered models: {', '.join(sorted(entrypoints))}"
)

@classmethod
Expand Down Expand Up @@ -187,6 +197,7 @@ class AutoEncoder(AutoModel[EncoderModule[TransformerConfig]]):
Encoder model loaded from the Hugging Face Model Hub.
"""

_base_cls = EncoderModule
_registry: Registry = registry.encoders

@classmethod
Expand All @@ -207,6 +218,7 @@ class AutoDecoder(AutoModel[DecoderModule[TransformerConfig, KeyValueCache]]):
Decoder module loaded from the Hugging Face Model Hub.
"""

_base_cls = DecoderModule
_registry = registry.decoders

@classmethod
Expand All @@ -227,6 +239,7 @@ class AutoCausalLM(AutoModel[CausalLMModule[TransformerConfig, KeyValueCache]]):
Causal LM model loaded from the Hugging Face Model Hub.
"""

_base_cls = CausalLMModule
_registry: Registry = registry.causal_lms

@classmethod
Expand Down
6 changes: 3 additions & 3 deletions curated_transformers/models/bert/encoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
from typing import Any, Dict, Mapping, Optional, Tuple, Type, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -114,8 +114,8 @@ def __init__(self, config: BERTConfig, *, device: Optional[torch.device] = None)
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("bert",)
def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
return config.get("model_type") == "bert"

@classmethod
def state_dict_from_hf(
Expand Down
6 changes: 3 additions & 3 deletions curated_transformers/models/camembert/encoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import Any, Dict, Optional, Tuple

import torch

Expand Down Expand Up @@ -27,5 +27,5 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No
super().__init__(config, device=device)

@classmethod
def hf_model_types(cls) -> Tuple[str, ...]:
return ("camembert",)
def is_supported(cls, config: Dict[str, Any]) -> bool:
return config.get("model_type") == "camembert"
6 changes: 3 additions & 3 deletions curated_transformers/models/falcon/causal_lm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Mapping, Optional, Set, Tuple, Type, TypeVar
from typing import Any, Dict, Mapping, Optional, Set, Tuple, Type, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -53,8 +53,8 @@ def state_dict_from_hf(
return state_dict_from_hf(params, CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("falcon", "RefinedWeb", "RefinedWebModel")
def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
return config.get("model_type") in ("falcon", "RefinedWeb", "RefinedWebModel")

@classmethod
def state_dict_to_hf(
Expand Down
6 changes: 3 additions & 3 deletions curated_transformers/models/falcon/decoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
from typing import Any, Dict, Mapping, Optional, Tuple, Type, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -88,8 +88,8 @@ def __init__(
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("falcon", "RefinedWeb", "RefinedWebModel")
def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
return config.get("model_type") in ("falcon", "RefinedWeb", "RefinedWebModel")

@classmethod
def state_dict_from_hf(
Expand Down
6 changes: 3 additions & 3 deletions curated_transformers/models/gpt_neox/causal_lm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Mapping, Optional, Set, Tuple, Type, TypeVar
from typing import Any, Dict, Mapping, Optional, Set, Tuple, Type, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -47,8 +47,8 @@ def __init__(
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("gpt_neox",)
def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
return config.get("model_type") == "gpt_neox"

@classmethod
def state_dict_from_hf(
Expand Down
6 changes: 3 additions & 3 deletions curated_transformers/models/gpt_neox/decoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
from typing import Any, Dict, Mapping, Optional, Tuple, Type, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -123,8 +123,8 @@ def __init__(
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("gpt_neox",)
def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
return config.get("model_type") == "gpt_neox"

@classmethod
def state_dict_from_hf(
Expand Down
13 changes: 8 additions & 5 deletions curated_transformers/models/hf_hub/mixin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
from typing import Any, Dict, Mapping, Optional, Tuple, Type, TypeVar

import torch
from fsspec import AbstractFileSystem
Expand Down Expand Up @@ -170,14 +170,17 @@ def from_hf_hub(

@classmethod
@abstractmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
"""
Get the Hugging Face model types supported by this model.
Check if the model with the given configuration is supported by this
class.
:param config:
Hugging Face model configuration.
:returns:
The supported model types.
Whether the model is supported by this class.
"""
...
raise NotImplementedError

@abstractmethod
def to(
Expand Down
6 changes: 3 additions & 3 deletions curated_transformers/models/llama/causal_lm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Mapping, Optional, Set, Tuple, Type, TypeVar
from typing import Any, Dict, Mapping, Optional, Set, Tuple, Type, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -48,8 +48,8 @@ def __init__(
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("llama",)
def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
return config.get("model_type") == "llama"

@classmethod
def state_dict_from_hf(
Expand Down
6 changes: 3 additions & 3 deletions curated_transformers/models/llama/decoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
from typing import Any, Dict, Mapping, Optional, Tuple, Type, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -129,8 +129,8 @@ def __init__(
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("llama",)
def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
return config.get("model_type") == "llama"

@classmethod
def state_dict_from_hf(
Expand Down
6 changes: 3 additions & 3 deletions curated_transformers/models/mpt/causal_lm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Mapping, Optional, Set, Tuple, Type, TypeVar
from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Type, TypeVar

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -85,8 +85,8 @@ def forward(
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("mpt",)
def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
return config.get("model_type") == "mpt"

@classmethod
def state_dict_from_hf(
Expand Down
6 changes: 3 additions & 3 deletions curated_transformers/models/mpt/decoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
from typing import Any, Dict, Mapping, Optional, Tuple, Type, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -124,8 +124,8 @@ def layer_norm():
self.output_layer_norm = layer_norm()

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("mpt",)
def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
return config.get("model_type") == "mpt"

@classmethod
def state_dict_from_hf(
Expand Down
6 changes: 3 additions & 3 deletions curated_transformers/models/roberta/encoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
from typing import Any, Dict, Mapping, Optional, Tuple, Type, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -114,8 +114,8 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("roberta",)
def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
return config.get("model_type") == "roberta"

@classmethod
def state_dict_from_hf(
Expand Down
6 changes: 3 additions & 3 deletions curated_transformers/models/xlm_roberta/encoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import Any, Dict, Optional, Tuple

import torch

Expand Down Expand Up @@ -27,5 +27,5 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No
super().__init__(config, device=device)

@classmethod
def hf_model_types(cls) -> Tuple[str, ...]:
return ("xlm-roberta",)
def is_supported(cls, config: Dict[str, Any]) -> bool:
return config.get("model_type") == "xlm-roberta"

0 comments on commit dfe6d96

Please sign in to comment.