From 581386c0c061211792c5bf12c9a7b44440abead3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Tue, 9 Apr 2024 19:38:09 +0200
Subject: [PATCH] Add support for loading parameters in-place (#370)

In some applications (e.g. spaCy Curated Transformers), we may already
have constructed the model and want to load parameters in-place. This
change adds in-place versions of the `from_*` class methods.

To support this properly, the in-place loaders should no longer need
access to the Hugging Face configuration. The Torch dtype
deserialization therefore moves from the `from_repo` method to
configuration deserialization, and all model configurations now also
take a `dtype` parameter.
---
 curated_transformers/models/albert/_hf.py      |   3 +
 curated_transformers/models/albert/config.py   |   5 +-
 curated_transformers/models/bert/_hf.py        |   3 +
 curated_transformers/models/bert/config.py     |   5 +-
 curated_transformers/models/falcon/_hf.py      |   4 +
 curated_transformers/models/falcon/config.py   |   5 +-
 curated_transformers/models/gpt_neox/_hf.py    |   3 +
 .../models/gpt_neox/config.py                  |   5 +-
 .../models/hf_hub/conversion.py                |  24 ++++
 curated_transformers/models/hf_hub/mixin.py    | 107 +++++++++++++++---
 curated_transformers/models/llama/_hf.py       |   3 +
 curated_transformers/models/llama/config.py    |   5 +-
 curated_transformers/models/mpt/_hf.py         |   3 +
 curated_transformers/models/mpt/config.py      |   5 +-
 curated_transformers/models/roberta/_hf.py     |   3 +
 curated_transformers/models/roberta/config.py  |   5 +-
 16 files changed, 167 insertions(+), 21 deletions(-)

diff --git a/curated_transformers/models/albert/_hf.py b/curated_transformers/models/albert/_hf.py
index fdd1ef3e..1880074f 100644
--- a/curated_transformers/models/albert/_hf.py
+++ b/curated_transformers/models/albert/_hf.py
@@ -1,5 +1,7 @@
 from typing import Any, Dict, List, Mapping, Optional, Tuple
 
+import torch
+
 from ...util.string import StringTransform, StringTransformations
 from ..hf_hub.conversion import (
     CommonHFKeys,
@@ -82,6 +84,7 @@ def conv_n_hidden_groups(config: ALBERTConfig) -> int:
 
 HF_CONFIG_KEYS: List[Tuple[HFConfigKey, Optional[HFConfigKeyDefault]]] = [
     (CommonHFKeys.ATTENTION_PROBS_DROPOUT_PROB, None),
+    (CommonHFKeys.DTYPE, HFConfigKeyDefault("float32")),
     (CommonHFKeys.EMBEDDING_SIZE, None),
     (CommonHFKeys.HIDDEN_DROPOUT_PROB, None),
     (CommonHFKeys.HIDDEN_SIZE, None),
diff --git a/curated_transformers/models/albert/config.py b/curated_transformers/models/albert/config.py
index bd71ea0f..21797128 100644
--- a/curated_transformers/models/albert/config.py
+++ b/curated_transformers/models/albert/config.py
@@ -50,6 +50,7 @@ class ALBERTConfig(TransformerConfig):
     def __init__(
         self,
         *,
+        dtype: torch.dtype = torch.float32,
         embedding_width: int = 128,
         hidden_width: int = 768,
         n_layers_per_group: int = 1,
@@ -67,6 +68,8 @@ def __init__(
         layer_norm_eps: float = 1e-12,
     ):
         """
+        :param dtype:
+            Data type to use for model parameters.
         :param embedding_width:
             Width of the embedding representations.
         :param hidden_width:
@@ -132,5 +135,5 @@ def __init__(
             n_layers_per_group=n_layers_per_group,
             n_hidden_groups=n_hidden_groups,
         )
-        self.dtype = torch.float32
+        self.dtype = dtype
         self.model_max_length = model_max_length
diff --git a/curated_transformers/models/bert/_hf.py b/curated_transformers/models/bert/_hf.py
index 5412a994..8a1f18ff 100644
--- a/curated_transformers/models/bert/_hf.py
+++ b/curated_transformers/models/bert/_hf.py
@@ -1,5 +1,7 @@
 from typing import Any, Dict, List, Mapping, Optional, Tuple
 
+import torch
+
 from ...util.string import StringTransform, StringTransformations
 from ..hf_hub.conversion import (
     CommonHFKeys,
@@ -63,6 +65,7 @@
 
 HF_CONFIG_KEYS: List[Tuple[HFConfigKey, Optional[HFConfigKeyDefault]]] = [
     (CommonHFKeys.ATTENTION_PROBS_DROPOUT_PROB, None),
+    (CommonHFKeys.DTYPE, HFConfigKeyDefault("float32")),
     (CommonHFKeys.HIDDEN_DROPOUT_PROB, None),
     (CommonHFKeys.HIDDEN_SIZE, None),
     (CommonHFKeys.HIDDEN_ACT, None),
diff --git a/curated_transformers/models/bert/config.py b/curated_transformers/models/bert/config.py
index 567129d2..1bb2cc53 100644
--- a/curated_transformers/models/bert/config.py
+++ b/curated_transformers/models/bert/config.py
@@ -25,6 +25,7 @@ class BERTConfig(TransformerConfig):
     def __init__(
         self,
         *,
+        dtype: torch.dtype = torch.float32,
         embedding_width: int = 768,
         hidden_width: int = 768,
         intermediate_width: int = 3072,
@@ -40,6 +41,8 @@ def __init__(
         layer_norm_eps: float = 1e-12,
     ):
         """
+        :param dtype:
+            Data type to use for model parameters.
         :param embedding_width:
             Width of the embedding representations.
         :param hidden_width:
@@ -99,5 +102,5 @@ def __init__(
             layer_norm_eps=layer_norm_eps,
             dropout_prob=hidden_dropout_prob,
         )
-        self.dtype = torch.float32
+        self.dtype = dtype
         self.model_max_length = model_max_length
diff --git a/curated_transformers/models/falcon/_hf.py b/curated_transformers/models/falcon/_hf.py
index 8a5f77aa..409299f9 100644
--- a/curated_transformers/models/falcon/_hf.py
+++ b/curated_transformers/models/falcon/_hf.py
@@ -1,5 +1,7 @@
 from typing import Any, Dict, List, Mapping, Optional, Tuple
 
+import torch
+
 from ...util.string import StringTransform, StringTransformations
 from ..hf_hub.conversion import (
     CommonCuratedToHFConverters,
@@ -148,6 +150,7 @@ def conv_new_decoder_architecture(config: FalconConfig) -> bool:
 HF_CONFIG_KEYS_REFINED_WEB_MODEL: List[
     Tuple[HFConfigKey, Optional[HFConfigKeyDefault]]
 ] = [
+    (CommonHFKeys.DTYPE, HFConfigKeyDefault("bfloat16")),
     (CommonHFKeys.HIDDEN_SIZE, None),
     (HFConfigKeys.N_HEAD, None),
     (HFConfigKeys.N_HEAD_KV, HFConfigKeyDefault(-1)),
@@ -165,6 +168,7 @@
 # Corresponds to the mainline implementation for Falcon models
 # in the `transformers` library.
 HF_CONFIG_KEYS_FALCON: List[Tuple[HFConfigKey, Optional[HFConfigKeyDefault]]] = [
+    (CommonHFKeys.DTYPE, HFConfigKeyDefault("bfloat16")),
     (CommonHFKeys.HIDDEN_SIZE, None),
     (HFConfigKeys.NUM_ATTENTION_HEADS, None),
     (HFConfigKeys.NUM_HEAD_KV, HFConfigKeyDefault(-1)),
diff --git a/curated_transformers/models/falcon/config.py b/curated_transformers/models/falcon/config.py
index 5f762887..95e61d7c 100644
--- a/curated_transformers/models/falcon/config.py
+++ b/curated_transformers/models/falcon/config.py
@@ -27,6 +27,7 @@ def __init__(
         self,
         *,
         attention_probs_dropout_prob: float = 0.0,
+        dtype: torch.dtype = torch.bfloat16,
         hidden_dropout_prob: float = 0.0,
         hidden_width: int = 2560,
         layer_norm_eps: float = 1e-5,
@@ -44,6 +45,8 @@ def __init__(
         """
         :param attention_probs_dropout_prob:
             Dropout to apply after attention.
+        :param dtype:
+            Data type to use for model parameters.
         :param hidden_dropout_prob:
             Dropout to apply to the hidden and embedding layers.
         :param hidden_width:
@@ -109,5 +112,5 @@ def __init__(
             layer_norm_eps=layer_norm_eps,
             n_hidden_layers=n_hidden_layers,
         )
-        self.dtype = torch.bfloat16
+        self.dtype = dtype
         self.new_decoder_architecture = new_decoder_architecture
diff --git a/curated_transformers/models/gpt_neox/_hf.py b/curated_transformers/models/gpt_neox/_hf.py
index 5533187c..98c759e6 100644
--- a/curated_transformers/models/gpt_neox/_hf.py
+++ b/curated_transformers/models/gpt_neox/_hf.py
@@ -1,5 +1,7 @@
 from typing import Any, Dict, List, Mapping, Optional, Tuple
 
+import torch
+
 from ...util.string import StringTransform, StringTransformations
 from ..hf_hub.conversion import (
     CommonHFKeys,
@@ -62,6 +64,7 @@ def conv_rotary_embedding_fraction(config: GPTNeoXConfig) -> float:
 
 
 HF_CONFIG_KEYS: List[Tuple[HFConfigKey, Optional[HFConfigKeyDefault]]] = [
+    (CommonHFKeys.DTYPE, HFConfigKeyDefault("float16")),
     (CommonHFKeys.HIDDEN_ACT, None),
     (CommonHFKeys.HIDDEN_SIZE, None),
     (CommonHFKeys.INTERMEDIATE_SIZE, None),
diff --git a/curated_transformers/models/gpt_neox/config.py b/curated_transformers/models/gpt_neox/config.py
index 34456039..82170f4e 100644
--- a/curated_transformers/models/gpt_neox/config.py
+++ b/curated_transformers/models/gpt_neox/config.py
@@ -26,6 +26,7 @@ def __init__(
         *,
         attention_probs_dropout_prob: float = 0.0,
         activation: Activation = Activation.GELU,
+        dtype: torch.dtype = torch.float16,
         hidden_dropout_prob: float = 0.0,
         hidden_width: int = 2560,
         intermediate_width: int = 10240,
@@ -43,6 +44,8 @@ def __init__(
             Dropout to apply after attention.
         :param activation:
             Activation used by the pointwise feed-forward layers.
+        :param dtype:
+            Data type to use for model parameters.
         :param hidden_dropout_prob:
             Dropout to apply to the hidden and embedding layers.
         :param hidden_width:
@@ -103,4 +106,4 @@ def __init__(
             layer_norm_eps=layer_norm_eps,
             n_hidden_layers=n_hidden_layers,
         )
-        self.dtype = torch.float16
+        self.dtype = dtype
diff --git a/curated_transformers/models/hf_hub/conversion.py b/curated_transformers/models/hf_hub/conversion.py
index c47115aa..8f6f0336 100644
--- a/curated_transformers/models/hf_hub/conversion.py
+++ b/curated_transformers/models/hf_hub/conversion.py
@@ -2,6 +2,7 @@
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
 
+import torch
 from torch import Tensor
 
 from ...layers import Activation
@@ -163,6 +164,10 @@ def attention_dropout(config: TransformerConfig) -> float:
     def activation(config: TransformerConfig) -> str:
         return config.layer.feedforward.activation.value
 
+    @staticmethod
+    def dtype(config: TransformerConfig) -> str:
+        return str(config.dtype).split(".")[1]
+
     @staticmethod
     def embedding_width(config: TransformerConfig) -> int:
         return config.embedding.embedding_width
@@ -200,6 +205,20 @@ def n_positions(config: TransformerConfig) -> Optional[int]:
         return config.embedding.n_positions
 
 
+class CommonHFToCuratedConverters:
+    """
+    Common functions to convert Hugging Face config
+    values to a compatible Curated config format.
+    """
+
+    @staticmethod
+    def dtype(serialized_dtype_str: str) -> Optional[torch.dtype]:
+        serialized_dtype = getattr(torch, serialized_dtype_str, None)
+        if not isinstance(serialized_dtype, torch.dtype):
+            raise ValueError(f"Invalid torch dtype `{serialized_dtype_str}`")
+        return serialized_dtype
+
+
 class CommonHFKeys:
     """
     Common Hugging Face config keys.
     """
@@ -212,6 +231,11 @@ class CommonHFKeys:
         # passing/calling static methods as without a class bound.
         lambda c: CommonCuratedToHFConverters.attention_dropout(c),
     )
+    DTYPE = HFConfigKey(
+        "torch_dtype",
+        ("dtype", lambda h: CommonHFToCuratedConverters.dtype(h)),
+        lambda c: CommonCuratedToHFConverters.dtype(c),
+    )
     EMBEDDING_SIZE = HFConfigKey(
         "embedding_size",
         "embedding_width",
diff --git a/curated_transformers/models/hf_hub/mixin.py b/curated_transformers/models/hf_hub/mixin.py
index 0782a7ac..d940fb72 100644
--- a/curated_transformers/models/hf_hub/mixin.py
+++ b/curated_transformers/models/hf_hub/mixin.py
@@ -169,6 +169,38 @@ def from_fsspec(
             quantization_config=quantization_config,
         )
 
+    def from_fsspec_(
+        self: Self,
+        *,
+        fs: AbstractFileSystem,
+        model_path: str,
+        fsspec_args: Optional[FsspecArgs] = None,
+        device: Optional[torch.device] = None,
+        quantization_config: Optional[BitsAndBytesConfig] = None,
+    ) -> Self:
+        """
+        Load parameters from an fsspec filesystem in-place into the model.
+
+        :param fs:
+            The filesystem to load the model from.
+        :param model_path:
+            The path of the model on the filesystem.
+        :param fsspec_args:
+            Implementation-specific keyword arguments to pass to fsspec
+            filesystem operations.
+        :param device:
+            Device on which the model is initialized.
+        :param quantization_config:
+            Configuration for loading quantized weights.
+        :returns:
+            Module with the parameters loaded.
+        """
+        return self.from_repo_(
+            repo=FsspecRepository(fs, model_path, fsspec_args),
+            device=device,
+            quantization_config=quantization_config,
+        )
+
     @classmethod
     def from_hf_hub(
         cls: Type[Self],
@@ -198,6 +230,34 @@ def from_hf_hub(
             quantization_config=quantization_config,
         )
 
+    def from_hf_hub_(
+        self: Self,
+        *,
+        name: str,
+        revision: str = "main",
+        device: Optional[torch.device] = None,
+        quantization_config: Optional[BitsAndBytesConfig] = None,
+    ) -> Self:
+        """
+        Load parameters from Hugging Face Hub in-place into the model.
+
+        :param name:
+            Model name.
+        :param revision:
+            Model revision.
+        :param device:
+            Device on which the model is initialized.
+        :param quantization_config:
+            Configuration for loading quantized weights.
+        :returns:
+            Module with the parameters loaded.
+        """
+        return self.from_repo_(
+            repo=HfHubRepository(name=name, revision=revision),
+            device=device,
+            quantization_config=quantization_config,
+        )
+
     @classmethod
     @abstractmethod
     def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
@@ -237,30 +297,49 @@ def from_repo(
         model = cls.from_hf_config(hf_config=config, device=torch.device("meta"))
         assert isinstance(model, Module)
+        return model.from_repo_(
+            repo=repo, device=device, quantization_config=quantization_config
+        )
+
+    def from_repo_(
+        self: Self,
+        *,
+        repo: Repository,
+        device: Optional[torch.device] = None,
+        quantization_config: Optional[BitsAndBytesConfig] = None,
+    ) -> Self:
+        """
+        Load parameters from a repository in-place into the model.
+
+        :param repo:
+            The repository to load from.
+        :param device:
+            Device on which to initialize the model.
+        :param quantization_config:
+            Configuration for loading quantized weights.
+        :returns:
+            Loaded model.
+        """
+        model_repo = ModelRepository(repo)
+
         # Convert the model to the expected dtype.
-        assert isinstance(model, TransformerModule)
-        dtype: torch.dtype = model.config.dtype
-        serialized_dtype_str = config.get("torch_dtype")
-        if serialized_dtype_str is not None:
-            serialized_dtype = getattr(torch, serialized_dtype_str, None)
-            if not isinstance(serialized_dtype, torch.dtype):
-                raise ValueError(f"Invalid torch dtype `{serialized_dtype_str}`")
-            dtype = serialized_dtype
-        model.to(dtype=dtype)
+        assert isinstance(self, TransformerModule)
+        dtype: torch.dtype = self.config.dtype
+        self.to(dtype=dtype)
 
         # Prepare for quantization.
         if quantization_config is not None:
-            tensor2param = prepare_module_for_quantization(model, quantization_config)  # type: ignore
+            tensor2param = prepare_module_for_quantization(self, quantization_config)  # type: ignore
         else:
             tensor2param = None
 
         # Download model and convert HF parameter names to ours.
         checkpoint_filenames, checkpoint_type = model_repo.model_checkpoints()
         load_model_from_checkpoints(
-            model,  # type:ignore
+            self,  # type:ignore
             filepaths=checkpoint_filenames,
             checkpoint_type=checkpoint_type,
-            state_dict_converter=cls.convert_hf_state_dict,
+            state_dict_converter=type(self).convert_hf_state_dict,
             tensor_to_param_converter=tensor2param,
             device=device,
         )
 
         # Ensure that any non-persistent buffers are also moved to
         # the correct device.
         if device is not None:
-            model.to(device)
+            self.to(device)
 
-        return model
+        return self
diff --git a/curated_transformers/models/llama/_hf.py b/curated_transformers/models/llama/_hf.py
index 2342f902..40a7cbb5 100644
--- a/curated_transformers/models/llama/_hf.py
+++ b/curated_transformers/models/llama/_hf.py
@@ -1,5 +1,7 @@
 from typing import Any, Dict, List, Mapping, Optional, Tuple
 
+import torch
+
 from ...util.string import StringTransform, StringTransformations
 from ..hf_hub.conversion import (
     CommonCuratedToHFConverters,
@@ -77,6 +79,7 @@ def conv_n_attention_keyvalue_heads(config: LlamaConfig) -> float:
 
 
 HF_CONFIG_KEYS: List[Tuple[HFConfigKey, Optional[HFConfigKeyDefault]]] = [
+    (CommonHFKeys.DTYPE, HFConfigKeyDefault("float16")),
     (CommonHFKeys.HIDDEN_ACT, None),
     (CommonHFKeys.HIDDEN_SIZE, None),
     (CommonHFKeys.INTERMEDIATE_SIZE, None),
diff --git a/curated_transformers/models/llama/config.py b/curated_transformers/models/llama/config.py
index 17b372f6..b2a62130 100644
--- a/curated_transformers/models/llama/config.py
+++ b/curated_transformers/models/llama/config.py
@@ -27,6 +27,7 @@ def __init__(
         *,
         attention_probs_dropout_prob: float = 0.0,
         activation: Activation = Activation.GELU,
+        dtype: torch.dtype = torch.float16,
         hidden_dropout_prob: float = 0.0,
         hidden_width: int = 2560,
         intermediate_width: int = 10240,
@@ -43,6 +44,8 @@ def __init__(
             Dropout to apply after attention.
         :param activation:
             Activation used by the pointwise feed-forward layers.
+        :param dtype:
+            Data type to use for model parameters.
         :param hidden_dropout_prob:
             Dropout to apply to the hidden and embedding layers.
         :param hidden_width:
@@ -100,4 +103,4 @@ def __init__(
             layer_norm_eps=rms_norm_eps,
             n_hidden_layers=n_hidden_layers,
         )
-        self.dtype = torch.float16
+        self.dtype = dtype
diff --git a/curated_transformers/models/mpt/_hf.py b/curated_transformers/models/mpt/_hf.py
index aa71c409..78ca008c 100644
--- a/curated_transformers/models/mpt/_hf.py
+++ b/curated_transformers/models/mpt/_hf.py
@@ -1,5 +1,7 @@
 from typing import Any, Dict, List, Mapping, Optional, Tuple
 
+import torch
+
 from ...util.string import StringTransform, StringTransformations
 from ..hf_hub.conversion import (
     CommonCuratedToHFConverters,
@@ -94,6 +96,7 @@ def conv_use_bias(config: MPTConfig) -> int:
 
 
 HF_CONFIG_KEYS: List[Tuple[HFConfigKey, Optional[HFConfigKeyDefault]]] = [
+    (CommonHFKeys.DTYPE, HFConfigKeyDefault("bfloat16")),
     (HFConfigKeys.D_MODEL, None),
     (HFConfigKeys.EXPANSION_RATIO, None),
     (HFConfigKeys.MAX_SEQ_LEN, None),
diff --git a/curated_transformers/models/mpt/config.py b/curated_transformers/models/mpt/config.py
index 17a90f73..8c619369 100644
--- a/curated_transformers/models/mpt/config.py
+++ b/curated_transformers/models/mpt/config.py
@@ -25,6 +25,7 @@ def __init__(
         *,
         attention_probs_dropout_prob: float = 0.0,
         activation: Activation = Activation.GELU,
+        dtype: torch.dtype = torch.bfloat16,
         hidden_dropout_prob: float = 0.0,
         hidden_width: int = 4096,
         intermediate_width_multiplier: int = 4,
@@ -40,6 +41,8 @@ def __init__(
             Dropout to apply after attention.
         :param activation:
             Activation used by the pointwise feed-forward layers.
+        :param dtype:
+            Data type to use for model parameters.
         :param hidden_dropout_prob:
             Dropout to apply to the hidden and embedding layers.
         :param hidden_width:
@@ -94,5 +97,5 @@ def __init__(
             layer_norm_eps=layer_norm_eps,
             n_hidden_layers=n_hidden_layers,
         )
-        self.dtype = torch.bfloat16
+        self.dtype = dtype
         self.model_max_length = model_max_length
diff --git a/curated_transformers/models/roberta/_hf.py b/curated_transformers/models/roberta/_hf.py
index e105fe50..5bb7a1e9 100644
--- a/curated_transformers/models/roberta/_hf.py
+++ b/curated_transformers/models/roberta/_hf.py
@@ -1,5 +1,7 @@
 from typing import Any, Dict, List, Mapping, Optional, Tuple
 
+import torch
+
 from ...util.string import StringTransform, StringTransformations
 from ..hf_hub.conversion import (
     CommonHFKeys,
@@ -73,6 +75,7 @@ def conv_padding_id(config: RoBERTaConfig) -> int:
 
 HF_CONFIG_KEYS: List[Tuple[HFConfigKey, Optional[HFConfigKeyDefault]]] = [
     (CommonHFKeys.ATTENTION_PROBS_DROPOUT_PROB, None),
+    (CommonHFKeys.DTYPE, HFConfigKeyDefault("float32")),
     (CommonHFKeys.HIDDEN_DROPOUT_PROB, None),
     (CommonHFKeys.HIDDEN_SIZE, None),
     (CommonHFKeys.HIDDEN_ACT, None),
diff --git a/curated_transformers/models/roberta/config.py b/curated_transformers/models/roberta/config.py
index 46fa3135..ed506b96 100644
--- a/curated_transformers/models/roberta/config.py
+++ b/curated_transformers/models/roberta/config.py
@@ -16,6 +16,7 @@ class RoBERTaConfig(BERTConfig):
     def __init__(
         self,
         *args,
+        dtype: torch.dtype = torch.float32,
         layer_norm_eps=1e-05,
         n_positions=514,
         padding_id=1,
@@ -24,6 +25,8 @@ def __init__(
         **kwargs
     ):
         """
+        :param dtype:
+            Data type to use for model parameters.
         :param embedding_width:
             Width of the embedding representations.
         :param hidden_width:
@@ -63,5 +66,5 @@ def __init__(
             n_pieces=n_pieces,
             **kwargs
        )
-        self.dtype = torch.float32
+        self.dtype = dtype
         self.padding_id = padding_id
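
For illustration (not part of the patch), a minimal usage sketch of the new
in-place loaders. The `curated_transformers.models.BERTEncoder` import and the
hub model name are assumptions/placeholders; only the `from_hf_hub` and
`from_hf_hub_` signatures come from the diff above.

    import torch

    from curated_transformers.models import BERTEncoder  # assumed import path

    # Construct the module first; here the existing hub loader is used for
    # brevity. In the spaCy Curated Transformers use case the module would
    # already exist.
    encoder = BERTEncoder.from_hf_hub(
        name="bert-base-uncased",  # placeholder model name
        device=torch.device("cpu"),
    )

    # Load parameters into the existing module instead of constructing a new
    # one. The in-place variants mirror the keyword-only classmethod signatures.
    encoder.from_hf_hub_(
        name="bert-base-uncased",
        revision="main",
        device=torch.device("cpu"),
    )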
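
Similarly, a short sketch of how the new `torch_dtype` handling is expected to
behave during configuration deserialization, assuming the
`curated_transformers.models.hf_hub.conversion` module path shown in the diff:

    import torch

    from curated_transformers.models.hf_hub.conversion import (
        CommonHFToCuratedConverters,
    )

    # Hugging Face "torch_dtype" strings are mapped to torch dtypes when the
    # configuration is deserialized; the in-place loaders then read
    # `config.dtype` instead of the raw Hugging Face config.
    assert CommonHFToCuratedConverters.dtype("bfloat16") is torch.bfloat16

    # Unknown names raise instead of silently falling back to a default.
    try:
        CommonHFToCuratedConverters.dtype("not_a_dtype")
    except ValueError as err:
        print(err)  # Invalid torch dtype `not_a_dtype`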