Add support for loading parameters in-place (#370)
In some applications (e.g. spaCy Curated Transformers), the model may
already be constructed and we want to load parameters into it in-place.
This change adds in-place versions of the `from_*` class methods.

To support this properly, the in-place loaders should no longer need to
access the configuration. The Torch dtype deserialization has therefore
moved from the `from_repo` method to the configuration deserialization,
and all model configurations now also take a `dtype` parameter.
danieldk committed Apr 9, 2024
1 parent 7d937cc commit 581386c
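
For context, a minimal usage sketch of the new in-place loading API (not part of the diff; it assumes `BERTEncoder` is exported from `curated_transformers.models` as elsewhere in the library, and the checkpoint name is illustrative):

import torch

from curated_transformers.models import BERTEncoder
from curated_transformers.models.bert.config import BERTConfig

# Construct the model up front; the parameter dtype now comes from the config.
config = BERTConfig(dtype=torch.float32)
encoder = BERTEncoder(config)

# Load parameters in-place with the new `from_hf_hub_` method, the in-place
# counterpart of the `from_hf_hub` class method.
encoder.from_hf_hub_(name="bert-base-uncased", device=torch.device("cpu"))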
Showing 16 changed files with 167 additions and 21 deletions.
3 changes: 3 additions & 0 deletions curated_transformers/models/albert/_hf.py
@@ -1,5 +1,7 @@
from typing import Any, Dict, List, Mapping, Optional, Tuple

import torch

from ...util.string import StringTransform, StringTransformations
from ..hf_hub.conversion import (
CommonHFKeys,
@@ -82,6 +84,7 @@ def conv_n_hidden_groups(config: ALBERTConfig) -> int:

HF_CONFIG_KEYS: List[Tuple[HFConfigKey, Optional[HFConfigKeyDefault]]] = [
(CommonHFKeys.ATTENTION_PROBS_DROPOUT_PROB, None),
(CommonHFKeys.DTYPE, HFConfigKeyDefault("float32")),
(CommonHFKeys.EMBEDDING_SIZE, None),
(CommonHFKeys.HIDDEN_DROPOUT_PROB, None),
(CommonHFKeys.HIDDEN_SIZE, None),
5 changes: 4 additions & 1 deletion curated_transformers/models/albert/config.py
@@ -50,6 +50,7 @@ class ALBERTConfig(TransformerConfig):
def __init__(
self,
*,
dtype: torch.dtype = torch.float32,
embedding_width: int = 128,
hidden_width: int = 768,
n_layers_per_group: int = 1,
@@ -67,6 +68,8 @@ def __init__(
layer_norm_eps: float = 1e-12,
):
"""
:param dtype:
Data type to use for model parameters.
:param embedding_width:
Width of the embedding representations.
:param hidden_width:
@@ -132,5 +135,5 @@ def __init__(
n_layers_per_group=n_layers_per_group,
n_hidden_groups=n_hidden_groups,
)
self.dtype = torch.float32
self.dtype = dtype
self.model_max_length = model_max_length
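
As a quick illustration of the new `dtype` constructor parameter (a sketch that only uses arguments shown in this diff):

import torch

from curated_transformers.models.albert.config import ALBERTConfig

# The parameter data type is now configurable instead of being hardcoded
# to torch.float32 in the constructor.
config = ALBERTConfig(dtype=torch.bfloat16)
assert config.dtype is torch.bfloat16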
3 changes: 3 additions & 0 deletions curated_transformers/models/bert/_hf.py
@@ -1,5 +1,7 @@
from typing import Any, Dict, List, Mapping, Optional, Tuple

import torch

from ...util.string import StringTransform, StringTransformations
from ..hf_hub.conversion import (
CommonHFKeys,
@@ -63,6 +65,7 @@

HF_CONFIG_KEYS: List[Tuple[HFConfigKey, Optional[HFConfigKeyDefault]]] = [
(CommonHFKeys.ATTENTION_PROBS_DROPOUT_PROB, None),
(CommonHFKeys.DTYPE, HFConfigKeyDefault("float32")),
(CommonHFKeys.HIDDEN_DROPOUT_PROB, None),
(CommonHFKeys.HIDDEN_SIZE, None),
(CommonHFKeys.HIDDEN_ACT, None),
5 changes: 4 additions & 1 deletion curated_transformers/models/bert/config.py
@@ -25,6 +25,7 @@ class BERTConfig(TransformerConfig):
def __init__(
self,
*,
dtype: torch.dtype = torch.float32,
embedding_width: int = 768,
hidden_width: int = 768,
intermediate_width: int = 3072,
@@ -40,6 +41,8 @@ def __init__(
layer_norm_eps: float = 1e-12,
):
"""
:param dtype:
Data type to use for model parameters.
:param embedding_width:
Width of the embedding representations.
:param hidden_width:
@@ -99,5 +102,5 @@ def __init__(
layer_norm_eps=layer_norm_eps,
dropout_prob=hidden_dropout_prob,
)
self.dtype = torch.float32
self.dtype = dtype
self.model_max_length = model_max_length
4 changes: 4 additions & 0 deletions curated_transformers/models/falcon/_hf.py
@@ -1,5 +1,7 @@
from typing import Any, Dict, List, Mapping, Optional, Tuple

import torch

from ...util.string import StringTransform, StringTransformations
from ..hf_hub.conversion import (
CommonCuratedToHFConverters,
@@ -148,6 +150,7 @@ def conv_new_decoder_architecture(config: FalconConfig) -> bool:
HF_CONFIG_KEYS_REFINED_WEB_MODEL: List[
Tuple[HFConfigKey, Optional[HFConfigKeyDefault]]
] = [
(CommonHFKeys.DTYPE, HFConfigKeyDefault("bfloat16")),
(CommonHFKeys.HIDDEN_SIZE, None),
(HFConfigKeys.N_HEAD, None),
(HFConfigKeys.N_HEAD_KV, HFConfigKeyDefault(-1)),
@@ -165,6 +168,7 @@ def conv_new_decoder_architecture(config: FalconConfig) -> bool:
# Corresponds to the mainline implementation for Falcon models
# in the `transformers` library.
HF_CONFIG_KEYS_FALCON: List[Tuple[HFConfigKey, Optional[HFConfigKeyDefault]]] = [
(CommonHFKeys.DTYPE, HFConfigKeyDefault("bfloat16")),
(CommonHFKeys.HIDDEN_SIZE, None),
(HFConfigKeys.NUM_ATTENTION_HEADS, None),
(HFConfigKeys.NUM_HEAD_KV, HFConfigKeyDefault(-1)),
5 changes: 4 additions & 1 deletion curated_transformers/models/falcon/config.py
@@ -27,6 +27,7 @@ def __init__(
self,
*,
attention_probs_dropout_prob: float = 0.0,
dtype: torch.dtype = torch.bfloat16,
hidden_dropout_prob: float = 0.0,
hidden_width: int = 2560,
layer_norm_eps: float = 1e-5,
@@ -44,6 +45,8 @@
"""
:param attention_probs_dropout_prob:
Dropout to apply after attention.
:param dtype:
Data type to use for model parameters.
:param hidden_dropout_prob:
Dropout to apply to the hidden and embedding layers.
:param hidden_width:
@@ -109,5 +112,5 @@ def __init__(
layer_norm_eps=layer_norm_eps,
n_hidden_layers=n_hidden_layers,
)
self.dtype = torch.bfloat16
self.dtype = dtype
self.new_decoder_architecture = new_decoder_architecture
3 changes: 3 additions & 0 deletions curated_transformers/models/gpt_neox/_hf.py
@@ -1,5 +1,7 @@
from typing import Any, Dict, List, Mapping, Optional, Tuple

import torch

from ...util.string import StringTransform, StringTransformations
from ..hf_hub.conversion import (
CommonHFKeys,
@@ -62,6 +64,7 @@ def conv_rotary_embedding_fraction(config: GPTNeoXConfig) -> float:


HF_CONFIG_KEYS: List[Tuple[HFConfigKey, Optional[HFConfigKeyDefault]]] = [
(CommonHFKeys.DTYPE, HFConfigKeyDefault("float16")),
(CommonHFKeys.HIDDEN_ACT, None),
(CommonHFKeys.HIDDEN_SIZE, None),
(CommonHFKeys.INTERMEDIATE_SIZE, None),
5 changes: 4 additions & 1 deletion curated_transformers/models/gpt_neox/config.py
@@ -26,6 +26,7 @@ def __init__(
*,
attention_probs_dropout_prob: float = 0.0,
activation: Activation = Activation.GELU,
dtype: torch.dtype = torch.float16,
hidden_dropout_prob: float = 0.0,
hidden_width: int = 2560,
intermediate_width: int = 10240,
@@ -43,6 +44,8 @@
Dropout to apply after attention.
:param activation:
Activation used by the pointwise feed-forward layers.
:param dtype:
Data type to use for model parameters.
:param hidden_dropout_prob:
Dropout to apply to the hidden and embedding layers.
:param hidden_width:
@@ -103,4 +106,4 @@ def __init__(
layer_norm_eps=layer_norm_eps,
n_hidden_layers=n_hidden_layers,
)
self.dtype = torch.float16
self.dtype = dtype
24 changes: 24 additions & 0 deletions curated_transformers/models/hf_hub/conversion.py
@@ -2,6 +2,7 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

import torch
from torch import Tensor

from ...layers import Activation
@@ -163,6 +164,10 @@ def attention_dropout(config: TransformerConfig) -> float:
def activation(config: TransformerConfig) -> str:
return config.layer.feedforward.activation.value

@staticmethod
def dtype(config: TransformerConfig) -> str:
return str(config.dtype).split(".")[1]

@staticmethod
def embedding_width(config: TransformerConfig) -> int:
return config.embedding.embedding_width
@@ -200,6 +205,20 @@ def n_positions(config: TransformerConfig) -> Optional[int]:
return config.embedding.n_positions


class CommonHFToCuratedConverters:
"""
Common functions to convert Hugging Face config
values to a compatible Curated config format.
"""

@staticmethod
def dtype(serialized_dtype_str: str) -> Optional[torch.dtype]:
serialized_dtype = getattr(torch, serialized_dtype_str, None)
if not isinstance(serialized_dtype, torch.dtype):
raise ValueError(f"Invalid torch dtype `{serialized_dtype_str}`")
return serialized_dtype


class CommonHFKeys:
"""
Common Hugging Face config keys.
@@ -212,6 +231,11 @@ class CommonHFKeys:
# passing/calling static methods as without a class bound.
lambda c: CommonCuratedToHFConverters.attention_dropout(c),
)
DTYPE = HFConfigKey(
"torch_dtype",
("dtype", lambda h: CommonHFToCuratedConverters.dtype(h)),
lambda c: CommonCuratedToHFConverters.dtype(c),
)
EMBEDDING_SIZE = HFConfigKey(
"embedding_size",
"embedding_width",
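To make the round trip behind the new `DTYPE` key concrete, here is a standalone sketch mirroring the two converters added above (`dtype_to_hf_string` and `dtype_from_hf_string` are hypothetical names; the logic follows `CommonCuratedToHFConverters.dtype` and `CommonHFToCuratedConverters.dtype`):

import torch

def dtype_to_hf_string(dtype: torch.dtype) -> str:
    # torch.bfloat16 -> "bfloat16", the format of HF's `torch_dtype` field.
    return str(dtype).split(".")[1]

def dtype_from_hf_string(name: str) -> torch.dtype:
    # "bfloat16" -> torch.bfloat16; reject names that are not dtypes.
    dtype = getattr(torch, name, None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Invalid torch dtype `{name}`")
    return dtype

assert dtype_to_hf_string(torch.bfloat16) == "bfloat16"
assert dtype_from_hf_string("float16") is torch.float16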
107 changes: 93 additions & 14 deletions curated_transformers/models/hf_hub/mixin.py
@@ -169,6 +169,38 @@ def from_fsspec(
quantization_config=quantization_config,
)

def from_fsspec_(
self: Self,
*,
fs: AbstractFileSystem,
model_path: str,
fsspec_args: Optional[FsspecArgs] = None,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> Self:
"""
Load parameters from an fsspec filesystem in-place into the model.

:param fs:
The filesystem to load the model from.
:param model_path:
The path of the model on the filesystem.
:param fsspec_args:
Implementation-specific keyword arguments to pass to fsspec
filesystem operations.
:param device:
Device on which the model is initialized.
:param quantization_config:
Configuration for loading quantized weights.
:returns:
Module with the parameters loaded.
"""
return self.from_repo_(
repo=FsspecRepository(fs, model_path, fsspec_args),
device=device,
quantization_config=quantization_config,
)

@classmethod
def from_hf_hub(
cls: Type[Self],
@@ -198,6 +230,34 @@ def from_hf_hub(
quantization_config=quantization_config,
)

def from_hf_hub_(
self: Self,
*,
name: str,
revision: str = "main",
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> Self:
"""
Load parameters from Hugging Face Hub in-place into the model.

:param name:
Model name.
:param revision:
Model revision.
:param device:
Device on which the model is initialized.
:param quantization_config:
Configuration for loading quantized weights.
:returns:
Module with the parameters loaded.
"""
return self.from_repo_(
repo=HfHubRepository(name=name, revision=revision),
device=device,
quantization_config=quantization_config,
)

@classmethod
@abstractmethod
def is_supported(cls: Type[Self], config: Dict[str, Any]) -> bool:
@@ -237,37 +297,56 @@ def from_repo(
model = cls.from_hf_config(hf_config=config, device=torch.device("meta"))
assert isinstance(model, Module)

return model.from_repo_(
repo=repo, device=device, quantization_config=quantization_config
)

def from_repo_(
self: Self,
*,
repo: Repository,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> Self:
"""
Load parameters from a repository in-place into the model.

:param repo:
The repository to load from.
:param device:
Device on which to initialize the model.
:param quantization_config:
Configuration for loading quantized weights.
:returns:
Loaded model.
"""
model_repo = ModelRepository(repo)

# Convert the model to the expected dtype.
assert isinstance(model, TransformerModule)
dtype: torch.dtype = model.config.dtype
serialized_dtype_str = config.get("torch_dtype")
if serialized_dtype_str is not None:
serialized_dtype = getattr(torch, serialized_dtype_str, None)
if not isinstance(serialized_dtype, torch.dtype):
raise ValueError(f"Invalid torch dtype `{serialized_dtype_str}`")
dtype = serialized_dtype
model.to(dtype=dtype)
assert isinstance(self, TransformerModule)
dtype: torch.dtype = self.config.dtype
self.to(dtype=dtype)

# Prepare for quantization.
if quantization_config is not None:
tensor2param = prepare_module_for_quantization(model, quantization_config) # type: ignore
tensor2param = prepare_module_for_quantization(self, quantization_config) # type: ignore
else:
tensor2param = None

# Download model and convert HF parameter names to ours.
checkpoint_filenames, checkpoint_type = model_repo.model_checkpoints()
load_model_from_checkpoints(
model, # type:ignore
self, # type:ignore
filepaths=checkpoint_filenames,
checkpoint_type=checkpoint_type,
state_dict_converter=cls.convert_hf_state_dict,
state_dict_converter=type(self).convert_hf_state_dict,
tensor_to_param_converter=tensor2param,
device=device,
)

# Ensure that any non-persistent buffers are also moved to
# the correct device.
if device is not None:
model.to(device)
self.to(device)

return model
return self
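
A short sketch of the fsspec variant shown above (the filesystem class comes from the `fsspec` package; the model class and directory path are illustrative):

from fsspec.implementations.local import LocalFileSystem

from curated_transformers.models import BERTEncoder
from curated_transformers.models.bert.config import BERTConfig

# Load parameters in-place from a checkpoint directory on local disk.
encoder = BERTEncoder(BERTConfig())
encoder.from_fsspec_(fs=LocalFileSystem(), model_path="/path/to/model-dir")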
3 changes: 3 additions & 0 deletions curated_transformers/models/llama/_hf.py
@@ -1,5 +1,7 @@
from typing import Any, Dict, List, Mapping, Optional, Tuple

import torch

from ...util.string import StringTransform, StringTransformations
from ..hf_hub.conversion import (
CommonCuratedToHFConverters,
@@ -77,6 +79,7 @@ def conv_n_attention_keyvalue_heads(config: LlamaConfig) -> float:


HF_CONFIG_KEYS: List[Tuple[HFConfigKey, Optional[HFConfigKeyDefault]]] = [
(CommonHFKeys.DTYPE, HFConfigKeyDefault("float16")),
(CommonHFKeys.HIDDEN_ACT, None),
(CommonHFKeys.HIDDEN_SIZE, None),
(CommonHFKeys.INTERMEDIATE_SIZE, None),