From 93cf07ff93af6ed06e00bb53f1abcef6bfbc0964 Mon Sep 17 00:00:00 2001
From: Madeesh Kannan
Date: Wed, 27 Sep 2023 11:02:41 +0200
Subject: [PATCH] Add support for converting Curated Transformer state dicts
 to Hugging Face compatible state dicts (#332)

* Add support for converting Curated Transformer state dicts to Hugging Face compatible state dicts.
* Typo
* Expand string transform class names
* Raise error during an unexpected reverse substitution
* `StringReplace` without `re`
* Add note about `removeprefix`
* Concat removed prefix
* Add `StringTransformations` to expose factory methods for `StringTransform` subclasses
* Simplify `StringSubInvertible` and rename to `StringSub`
---
 curated_transformers/models/albert/_hf.py | 119 ++++-----
 curated_transformers/models/albert/encoder.py | 15 +-
 curated_transformers/models/bert/_hf.py | 121 ++++-----
 curated_transformers/models/bert/encoder.py | 15 +-
 curated_transformers/models/falcon/_hf.py | 94 +++----
 .../models/falcon/causal_lm.py | 15 +-
 curated_transformers/models/falcon/decoder.py | 15 +-
 curated_transformers/models/gpt_neox/_hf.py | 83 ++----
 .../models/gpt_neox/causal_lm.py | 15 +-
 .../models/gpt_neox/decoder.py | 15 +-
 .../models/hf_hub/__init__.py | 1 +
 .../models/hf_hub/conversion.py | 101 +++++++
 .../models/{hf_hub.py => hf_hub/mixin.py} | 104 +++-----
 curated_transformers/models/llama/_hf.py | 82 +++---
 .../models/llama/causal_lm.py | 15 +-
 curated_transformers/models/llama/decoder.py | 15 +-
 curated_transformers/models/mpt/_hf.py | 82 ++----
 curated_transformers/models/mpt/causal_lm.py | 15 +-
 curated_transformers/models/mpt/decoder.py | 15 +-
 curated_transformers/models/roberta/_hf.py | 105 ++++----
 .../models/roberta/encoder.py | 15 +-
 .../tests/models/albert/test_encoder.py | 14 +-
 .../tests/models/bert/test_encoder.py | 14 +-
 .../tests/models/camembert/test_encoder.py | 14 +-
 .../tests/models/falcon/test_decoder.py | 20 +-
 .../tests/models/gpt_neox/test_causal_lm.py | 16 +-
 .../tests/models/gpt_neox/test_decoder.py | 16 +-
 .../tests/models/llama/test_causal_lm.py | 13 +-
 .../tests/models/llama/test_decoder.py | 13 +-
 .../tests/models/mpt/test_causal_lm.py | 14 +-
 .../tests/models/mpt/test_decoder.py | 14 +-
 .../tests/models/roberta/test_encoder.py | 14 +-
 curated_transformers/tests/models/util.py | 41 +++
 .../tests/models/xlm_roberta/test_encoder.py | 14 +-
 curated_transformers/util/string.py | 249 ++++++++++++++++++
 docs/source/api-compat.rst | 2 +
 36 files changed, 1004 insertions(+), 521 deletions(-)
 create mode 100644 curated_transformers/models/hf_hub/__init__.py
 create mode 100644 curated_transformers/models/hf_hub/conversion.py
 rename curated_transformers/models/{hf_hub.py => hf_hub/mixin.py} (74%)
 create mode 100644 curated_transformers/util/string.py

diff --git a/curated_transformers/models/albert/_hf.py b/curated_transformers/models/albert/_hf.py
index dd63632f..bcaa60e2 100644
--- a/curated_transformers/models/albert/_hf.py
+++ b/curated_transformers/models/albert/_hf.py
@@ -1,25 +1,56 @@
-import re
-from types import MappingProxyType
-from typing import Any, Callable, Dict, Mapping, Tuple, Union
-
-from torch import Tensor
+from typing import Any, Callable, Dict, List, Tuple, Union
 
 from ...layers.activations import Activation
-from ..hf_hub import _process_hf_keys
+from ...util.string import StringTransform, StringTransformations
+from ..hf_hub.conversion import process_hf_keys
 from .config import ALBERTConfig
 
-HF_KEY_TO_CURATED_KEY = MappingProxyType(
-    {
-        "embeddings.word_embeddings.weight": 
"embeddings.piece_embeddings.weight", - "embeddings.token_type_embeddings.weight": "embeddings.type_embeddings.weight", - "embeddings.position_embeddings.weight": "embeddings.position_embeddings.weight", - "embeddings.LayerNorm.weight": "embeddings.embed_output_layer_norm.weight", - "embeddings.LayerNorm.bias": "embeddings.embed_output_layer_norm.bias", - # Embedding projection - "encoder.embedding_hidden_mapping_in.weight": "embeddings.projection.weight", - "encoder.embedding_hidden_mapping_in.bias": "embeddings.projection.bias", - } -) +# Order-dependent. +HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [ + # Prefixes. + StringTransformations.remove_prefix("albert.", reversible=False), + StringTransformations.regex_sub( + (r"^encoder\.(embedding_|albert_layer)", "\\1"), + (r"^(embedding_|albert_layer)", "encoder.\\1"), + ), + # Layer groups + StringTransformations.regex_sub( + (r"^albert_layer_groups\.", "groups."), (r"^groups\.", "albert_layer_groups.") + ), + # Inner layers. + StringTransformations.sub(".albert_layers.", ".group_layers."), + # Attention blocks. + StringTransformations.sub(".attention.", ".mha."), + StringTransformations.sub(".mha.LayerNorm", ".attn_residual_layer_norm"), + StringTransformations.sub(".mha.dense", ".mha.output"), + # Pointwise feed-forward layers. + StringTransformations.sub(".ffn.", ".ffn.intermediate."), + StringTransformations.sub(".ffn_output.", ".ffn.output."), + StringTransformations.sub(".full_layer_layer_norm.", ".ffn_residual_layer_norm."), + # Embeddings. + StringTransformations.replace( + "embeddings.word_embeddings.weight", "embeddings.piece_embeddings.weight" + ), + StringTransformations.replace( + "embeddings.token_type_embeddings.weight", "embeddings.type_embeddings.weight" + ), + StringTransformations.replace( + "embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight" + ), + StringTransformations.replace( + "embeddings.LayerNorm.weight", "embeddings.embed_output_layer_norm.weight" + ), + StringTransformations.replace( + "embeddings.LayerNorm.bias", "embeddings.embed_output_layer_norm.bias" + ), + # Embedding projection. + StringTransformations.replace( + "embedding_hidden_mapping_in.weight", "embeddings.projection.weight" + ), + StringTransformations.replace( + "embedding_hidden_mapping_in.bias", "embeddings.projection.bias" + ), +] HF_CONFIG_KEY_MAPPING: Dict[str, Union[str, Tuple[str, Callable]]] = { "attention_probs_dropout_prob": "attention_probs_dropout_prob", @@ -40,55 +71,5 @@ def convert_hf_config(hf_config: Any) -> ALBERTConfig: - kwargs = _process_hf_keys("ALBERT", hf_config, HF_CONFIG_KEY_MAPPING) + kwargs = process_hf_keys("ALBERT", hf_config, HF_CONFIG_KEY_MAPPING) return ALBERTConfig(model_max_length=hf_config["max_position_embeddings"], **kwargs) - - -def convert_hf_state_dict(params: Mapping[str, Tensor]) -> Mapping[str, Tensor]: - # Strip the `albert` prefix from ALBERT model parameters. - stripped_params = {re.sub(r"^albert\.", "", k): v for k, v in params.items()} - - # The ALBERT encoder parameters have the following form: - # - # encoder.albert_layer_groups.{hidden_group}.albert_layers.{inner_layer}.{param_name} - # - # hidden_group is in [0, n_hidden_group) - # inner_layer is in [0, n_layers_per_group) - - out = {} - for name, parameter in stripped_params.items(): - if "encoder.albert_layer" not in name: - continue - - # TODO: Make these substitutions less ugly. - - # Remove the prefix and rename. 
- name = re.sub(r"^encoder\.", "", name) - - # Layer groups - name = re.sub(r"^albert_layer_groups\.", "groups.", name) - - # Inner layers. - name = re.sub(r"\.albert_layers\.", ".group_layers.", name) - - # Attention blocks. - name = re.sub(r"\.attention\.", ".mha.", name) - name = re.sub(r"\.mha\.LayerNorm", r".attn_residual_layer_norm", name) - name = re.sub(r"\.mha\.dense\.", r".mha.output.", name) - - # Pointwise feed-forward layers. - name = re.sub(r"\.ffn\.", r".ffn.intermediate.", name) - name = re.sub(r"\.ffn_output\.", r".ffn.output.", name) - name = re.sub( - r"\.full_layer_layer_norm\.", - r".ffn_residual_layer_norm.", - name, - ) - - out[name] = parameter - - for hf_name, curated_name in HF_KEY_TO_CURATED_KEY.items(): - if hf_name in stripped_params: - out[curated_name] = stripped_params[hf_name] - - return out diff --git a/curated_transformers/models/albert/encoder.py b/curated_transformers/models/albert/encoder.py index ffaf6489..1a247c23 100644 --- a/curated_transformers/models/albert/encoder.py +++ b/curated_transformers/models/albert/encoder.py @@ -11,9 +11,10 @@ TransformerEmbeddings, ) from ..hf_hub import FromHFHub +from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..module import EncoderModule from ..output import ModelOutput -from ._hf import convert_hf_config, convert_hf_state_dict +from ._hf import HF_PARAM_KEY_TRANSFORMS, convert_hf_config from .config import ALBERTConfig from .layer_group import ALBERTLayerGroup @@ -99,8 +100,16 @@ def forward( return ModelOutput(all_outputs=[embeddings, *layer_outputs]) @classmethod - def convert_hf_state_dict(cls, params: Mapping[str, Tensor]): - return convert_hf_state_dict(params) + def state_dict_from_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return state_dict_from_hf(params, HF_PARAM_KEY_TRANSFORMS) + + @classmethod + def state_dict_to_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return state_dict_to_hf(params, HF_PARAM_KEY_TRANSFORMS) @classmethod def from_hf_config( diff --git a/curated_transformers/models/bert/_hf.py b/curated_transformers/models/bert/_hf.py index c876b862..a920c97c 100644 --- a/curated_transformers/models/bert/_hf.py +++ b/curated_transformers/models/bert/_hf.py @@ -1,23 +1,59 @@ -import re -from types import MappingProxyType -from typing import Any, Callable, Dict, Mapping, Tuple, Union - -from torch import Tensor +from typing import Any, Callable, Dict, List, Tuple, Union from ...layers.activations import Activation -from ..hf_hub import _process_hf_keys +from ...util.string import StringTransform, StringTransformations +from ..hf_hub.conversion import process_hf_keys from .config import BERTConfig -HF_KEY_TO_CURATED_KEY = MappingProxyType( - { - "embeddings.word_embeddings.weight": "embeddings.piece_embeddings.weight", - "embeddings.token_type_embeddings.weight": "embeddings.type_embeddings.weight", - "embeddings.position_embeddings.weight": "embeddings.position_embeddings.weight", - "embeddings.LayerNorm.weight": "embeddings.embed_output_layer_norm.weight", - "embeddings.LayerNorm.bias": "embeddings.embed_output_layer_norm.bias", - } -) - +# Order-dependent. +HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [ + # Old HF parameter names (one-way transforms). + StringTransformations.regex_sub((r"\.gamma$", ".weight"), backward=None), + StringTransformations.regex_sub((r"\.beta$", ".bias"), backward=None), + # Prefixes. 
+ StringTransformations.remove_prefix("bert.", reversible=False), + StringTransformations.regex_sub( + (r"^encoder\.(layer\.)", "\\1"), + (r"^(layer\.)", "encoder.\\1"), + ), + # Layers. + StringTransformations.regex_sub((r"^layer", "layers"), (r"^layers", "layer")), + # Attention blocks. + StringTransformations.regex_sub( + (r"\.attention\.self\.(query|key|value)", ".mha.\\1"), + (r"\.mha\.(query|key|value)", ".attention.self.\\1"), + ), + StringTransformations.sub(".attention.output.dense", ".mha.output"), + StringTransformations.sub( + r".attention.output.LayerNorm", ".attn_residual_layer_norm" + ), + # Pointwise feed-forward layers. + StringTransformations.sub(".intermediate.dense", ".ffn.intermediate"), + StringTransformations.regex_sub( + (r"(\.\d+)\.output\.LayerNorm", "\\1.ffn_residual_layer_norm"), + (r"(\.\d+)\.ffn_residual_layer_norm", "\\1.output.LayerNorm"), + ), + StringTransformations.regex_sub( + (r"(\.\d+)\.output\.dense", "\\1.ffn.output"), + (r"(\.\d+)\.ffn\.output", "\\1.output.dense"), + ), + # Embeddings. + StringTransformations.replace( + "embeddings.word_embeddings.weight", "embeddings.piece_embeddings.weight" + ), + StringTransformations.replace( + "embeddings.token_type_embeddings.weight", "embeddings.type_embeddings.weight" + ), + StringTransformations.replace( + "embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight" + ), + StringTransformations.replace( + "embeddings.LayerNorm.weight", "embeddings.embed_output_layer_norm.weight" + ), + StringTransformations.replace( + "embeddings.LayerNorm.bias", "embeddings.embed_output_layer_norm.bias" + ), +] HF_CONFIG_KEY_MAPPING: Dict[str, Union[str, Tuple[str, Callable]]] = { "attention_probs_dropout_prob": "attention_probs_dropout_prob", @@ -35,61 +71,10 @@ def convert_hf_config(hf_config: Any) -> BERTConfig: - kwargs = _process_hf_keys("BERT", hf_config, HF_CONFIG_KEY_MAPPING) + kwargs = process_hf_keys("BERT", hf_config, HF_CONFIG_KEY_MAPPING) return BERTConfig( embedding_width=hf_config["hidden_size"], model_max_length=hf_config["max_position_embeddings"], **kwargs, ) - - -def convert_hf_state_dict(params: Mapping[str, Tensor]) -> Mapping[str, Tensor]: - out = {} - - renamed_params = _rename_old_hf_names(params) - - # Strip the `bert` prefix from BERT model parameters. - stripped_params = {re.sub(r"^bert\.", "", k): v for k, v in renamed_params.items()} - - for name, parameter in stripped_params.items(): - if "encoder.layer." not in name: - continue - - # TODO: Make these substitutions less ugly. - - # Remove the prefix and rename the internal 'layers' variable. - name = re.sub(r"^encoder\.", "", name) - name = re.sub(r"^layer", "layers", name) - - # The HF model has one more level of indirection for the output layers in their - # attention heads and the feed-forward network layers. 
- name = re.sub(r"\.attention\.self\.(query|key|value)", r".mha.\1", name) - name = re.sub(r"\.attention\.(output)\.dense", r".mha.\1", name) - name = re.sub( - r"\.attention\.output\.LayerNorm", r".attn_residual_layer_norm", name - ) - name = re.sub(r"\.(intermediate)\.dense", r".ffn.\1", name) - name = re.sub( - r"(\.\d+)\.output\.LayerNorm", r"\1.ffn_residual_layer_norm", name - ) - name = re.sub(r"(\.\d+)\.(output)\.dense", r"\1.ffn.\2", name) - - out[name] = parameter - - for hf_name, curated_name in HF_KEY_TO_CURATED_KEY.items(): - if hf_name in stripped_params: - out[curated_name] = stripped_params[hf_name] - - return out - - -def _rename_old_hf_names( - params: Mapping[str, Tensor], -) -> Mapping[str, Tensor]: - out = {} - for name, parameter in params.items(): - name = re.sub(r"\.gamma$", ".weight", name) - name = re.sub(r"\.beta$", ".bias", name) - out[name] = parameter - return out diff --git a/curated_transformers/models/bert/encoder.py b/curated_transformers/models/bert/encoder.py index 0dc55d1b..363b2875 100644 --- a/curated_transformers/models/bert/encoder.py +++ b/curated_transformers/models/bert/encoder.py @@ -16,8 +16,9 @@ TransformerLayerNorms, ) from ..hf_hub import FromHFHub +from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..transformer import TransformerEncoder -from ._hf import convert_hf_config, convert_hf_state_dict +from ._hf import HF_PARAM_KEY_TRANSFORMS, convert_hf_config from .config import BERTConfig # Only provided as typing.Self in Python 3.11+. @@ -105,8 +106,16 @@ def __init__(self, config: BERTConfig, *, device: Optional[torch.device] = None) ) @classmethod - def convert_hf_state_dict(cls, params: Mapping[str, Tensor]): - return convert_hf_state_dict(params) + def state_dict_from_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return state_dict_from_hf(params, HF_PARAM_KEY_TRANSFORMS) + + @classmethod + def state_dict_to_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return state_dict_to_hf(params, HF_PARAM_KEY_TRANSFORMS) @classmethod def from_hf_config( diff --git a/curated_transformers/models/falcon/_hf.py b/curated_transformers/models/falcon/_hf.py index d45ccef0..c4d6d045 100644 --- a/curated_transformers/models/falcon/_hf.py +++ b/curated_transformers/models/falcon/_hf.py @@ -1,16 +1,44 @@ -import re -from typing import Any, Callable, Dict, Mapping, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Union -from torch import Tensor - -from ..hf_hub import _process_hf_keys -from ..module import DecoderModule +from ...util.string import StringTransform, StringTransformations +from ..hf_hub.conversion import process_hf_keys from .config import FalconConfig ATTENTION_DROPOUT = "attention_probs_dropout_prob" HIDDEN_DROPOUT = "hidden_dropout_prob" EXTRA_KWARG_KEYS = [ATTENTION_DROPOUT, HIDDEN_DROPOUT] + +# Order-dependent. +COMMON_HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [ + StringTransformations.regex_sub((r"^h\.", "layers."), (r"^layers\.", "h.")), + StringTransformations.sub("decoder.h.", "decoder.layers."), + # Attention blocks. + StringTransformations.sub(".self_attention", ".mha"), + StringTransformations.sub(".mha.query_key_value", ".mha.input"), + StringTransformations.sub(".mha.dense", ".mha.output"), + # Pointwise feedforward. + StringTransformations.sub(".mlp", ".ffn"), + StringTransformations.sub(".dense_h_to_4h", ".intermediate"), + StringTransformations.sub(".ffn.dense_4h_to_h", ".ffn.output"), + # Layer norms. 
+ StringTransformations.sub(".input_layernorm", ".attn_layer_norm"), + StringTransformations.sub(".ln_attn", ".attn_input_layer_norm"), + StringTransformations.sub(".post_attention_layernorm", ".ffn_layer_norm"), + StringTransformations.sub(".ln_mlp", ".ffn_input_layer_norm"), + StringTransformations.sub("ln_f.", "output_layer_norm."), + # Embeddings. + StringTransformations.sub("word_embeddings.", "embeddings.piece_embeddings."), + StringTransformations.sub("lm_head.", "output_embeddings."), +] + +DECODER_HF_PARAM_KEY_TRANSFORMS = [ + StringTransformations.remove_prefix("transformer.", reversible=False) +] + COMMON_HF_PARAM_KEY_TRANSFORMS +CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS = [ + StringTransformations.sub("transformer.", "decoder."), +] + COMMON_HF_PARAM_KEY_TRANSFORMS + HF_CONFIG_KEY_MAPPING_REFINED_WEB_MODEL: Dict[str, Union[str, Tuple[str, Callable]]] = { "hidden_size": "hidden_width", "layer_norm_epsilon": "layer_norm_eps", @@ -45,7 +73,7 @@ def convert_hf_config(hf_config: Any) -> FalconConfig: def _convert_hf_config_refined_web_model(hf_config: Any) -> FalconConfig: - kwargs = _process_hf_keys( + kwargs = process_hf_keys( "Falcon", hf_config, HF_CONFIG_KEY_MAPPING_REFINED_WEB_MODEL, EXTRA_KWARG_KEYS ) @@ -79,7 +107,7 @@ def _convert_hf_config_refined_web_model(hf_config: Any) -> FalconConfig: def _convert_hf_config_falcon(hf_config: Any) -> FalconConfig: - kwargs = _process_hf_keys( + kwargs = process_hf_keys( "Falcon", hf_config, HF_CONFIG_KEY_MAPPING_FALCON, EXTRA_KWARG_KEYS ) @@ -104,53 +132,3 @@ def _convert_hf_config_falcon(hf_config: Any) -> FalconConfig: rotary_embedding_fraction=1.0, **kwargs, ) - - -def convert_hf_state_dict(cls, params: Mapping[str, Tensor]) -> Mapping[str, Tensor]: - """ - Convert state dict from HF paramater naming to ours. - The function is insensitive to prefixes, to allow loading - both the decoder and the full LM. - """ - if issubclass(cls, DecoderModule): - stripped_params = { - re.sub(r"^transformer\.", "", k): v for k, v in params.items() - } - else: - stripped_params = { - re.sub(r"^transformer\.", "decoder.", k): v for k, v in params.items() - } - - out = {} - for name, parameter in stripped_params.items(): - # These parameters are all created on-the-fly. 
- if "rotary_emb" in name or "attention.bias" in name or "masked_bias" in name: - continue - - name = re.sub(r"^h\.", "layers.", name) - name = re.sub(r"decoder\.h\.", "decoder.layers.", name) - - # Attention - name = re.sub(r"\.self_attention", r".mha", name) - name = re.sub(r"\.query_key_value", r".input", name) - name = re.sub(r"\.mha\.dense", r".mha.output", name) - - # Pointwise feedforward - name = re.sub(r"\.mlp", r".ffn", name) - name = re.sub(r"\.dense_h_to_4h", r".intermediate", name) - name = re.sub(r"\.dense_4h_to_h", r".output", name) - - # Layer norms - name = re.sub(r"\.input_layernorm", r".attn_layer_norm", name) - name = re.sub(r"\.ln_attn", r".attn_input_layer_norm", name) - name = re.sub(r"\.post_attention_layernorm", r".ffn_layer_norm", name) - name = re.sub(r"\.ln_mlp", r".ffn_input_layer_norm", name) - name = re.sub(r"ln_f\.", r"output_layer_norm.", name) - - # Embeddings - name = re.sub(r"word_embeddings\.", r"embeddings.piece_embeddings.", name) - name = re.sub(r"lm_head\.", r"output_embeddings.", name) - - out[name] = parameter - - return out diff --git a/curated_transformers/models/falcon/causal_lm.py b/curated_transformers/models/falcon/causal_lm.py index ae570d48..91860d06 100644 --- a/curated_transformers/models/falcon/causal_lm.py +++ b/curated_transformers/models/falcon/causal_lm.py @@ -6,8 +6,9 @@ from ...quantization.quantizable import Quantizable from ..hf_hub import FromHFHub +from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..transformer import TransformerCausalLM -from ._hf import convert_hf_config, convert_hf_state_dict +from ._hf import CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS, convert_hf_config from .config import FalconConfig from .decoder import FalconDecoder @@ -46,8 +47,16 @@ def __init__( ) @classmethod - def convert_hf_state_dict(cls, params: Mapping[str, Tensor]): - return convert_hf_state_dict(cls, params) + def state_dict_from_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return state_dict_from_hf(params, CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS) + + @classmethod + def state_dict_to_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return state_dict_to_hf(params, CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS) @classmethod def from_hf_config( diff --git a/curated_transformers/models/falcon/decoder.py b/curated_transformers/models/falcon/decoder.py index 0a16154e..478a70d0 100644 --- a/curated_transformers/models/falcon/decoder.py +++ b/curated_transformers/models/falcon/decoder.py @@ -22,8 +22,9 @@ TransformerLayerNorms, ) from ..hf_hub import FromHFHub +from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..transformer import TransformerDecoder -from ._hf import convert_hf_config, convert_hf_state_dict +from ._hf import DECODER_HF_PARAM_KEY_TRANSFORMS, convert_hf_config from .config import FalconConfig from .layer import OldFalconDecoderLayer @@ -86,8 +87,16 @@ def __init__( ) @classmethod - def convert_hf_state_dict(cls, params: Mapping[str, Tensor]): - return convert_hf_state_dict(cls, params) + def state_dict_from_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return state_dict_from_hf(params, DECODER_HF_PARAM_KEY_TRANSFORMS) + + @classmethod + def state_dict_to_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return state_dict_to_hf(params, DECODER_HF_PARAM_KEY_TRANSFORMS) @classmethod def from_hf_config( diff --git a/curated_transformers/models/gpt_neox/_hf.py 
b/curated_transformers/models/gpt_neox/_hf.py index e3f21c4b..4f8905d9 100644 --- a/curated_transformers/models/gpt_neox/_hf.py +++ b/curated_transformers/models/gpt_neox/_hf.py @@ -1,17 +1,39 @@ -import re -from typing import Any, Callable, Dict, Mapping, Tuple, Union - -from torch import Tensor +from typing import Any, Callable, Dict, List, Tuple, Union from ...layers.activations import Activation -from ..hf_hub import _process_hf_keys -from ..module import DecoderModule +from ...util.string import StringTransform, StringTransformations +from ..hf_hub.conversion import process_hf_keys from .config import GPTNeoXConfig ATTENTION_DROPOUT = "attention_probs_dropout_prob" HIDDEN_DROPOUT = "hidden_dropout_prob" EXTRA_KWARG_KEYS = [ATTENTION_DROPOUT, HIDDEN_DROPOUT] +# Order-dependent. +COMMON_HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [ + StringTransformations.sub("gpt_neox", "decoder"), + # Attention blocks. + StringTransformations.sub(".attention", ".mha"), + StringTransformations.sub(".mha.query_key_value", ".mha.input"), + StringTransformations.sub(".mha.dense", ".mha.output"), + # Pointwise feedforward. + StringTransformations.sub(".mlp", ".ffn"), + StringTransformations.sub(".dense_h_to_4h", ".intermediate"), + StringTransformations.sub(".ffn.dense_4h_to_h", ".ffn.output"), + # Layer norms. + StringTransformations.sub(".input_layernorm", ".attn_input_layer_norm"), + StringTransformations.sub(".post_attention_layernorm", ".ffn_input_layer_norm"), + StringTransformations.sub("final_layer_norm.", "output_layer_norm."), + # Embeddings. + StringTransformations.sub("embed_in.", "embeddings.piece_embeddings."), + StringTransformations.sub("embed_out.", "output_embeddings."), +] + +DECODER_HF_PARAM_KEY_TRANSFORMS = [ + StringTransformations.remove_prefix("gpt_neox.", reversible=False) +] + COMMON_HF_PARAM_KEY_TRANSFORMS +CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS = COMMON_HF_PARAM_KEY_TRANSFORMS + HF_CONFIG_KEY_MAPPING: Dict[str, Union[str, Tuple[str, Callable]]] = { "hidden_act": ("activation", Activation), "hidden_size": "hidden_width", @@ -27,57 +49,10 @@ def convert_hf_config(hf_config: Any) -> GPTNeoXConfig: - kwargs = _process_hf_keys( + kwargs = process_hf_keys( "GPT-NeoX", hf_config, HF_CONFIG_KEY_MAPPING, EXTRA_KWARG_KEYS ) return GPTNeoXConfig( model_max_length=hf_config["max_position_embeddings"], **kwargs, ) - - -def convert_hf_state_dict(cls, params: Mapping[str, Tensor]) -> Mapping[str, Tensor]: - """Convert state dict from HF paramater naming to ours. - The function is insensitive to prefixes, to allow loading - both the decoder and the full LM.""" - if issubclass(cls, DecoderModule): - stripped_params = { - re.sub(r"^gpt_neox\.", "", k): v - for k, v in params.items() - # The decoder does not the output embeddings, avoid unexpected key. - if k != "embed_out.weight" - } - else: - # Rewrap as dict if necessay to make MyPy happy. - stripped_params = dict(params) - - out = {} - for name, parameter in stripped_params.items(): - # These parameters are all created on-the-fly. 
- if "rotary_emb" in name or "attention.bias" in name or "masked_bias" in name: - continue - - name = name.replace("gpt_neox", "decoder") - - # Attention - name = re.sub(r"\.attention", r".mha", name) - name = re.sub(r"\.query_key_value", r".input", name) - name = re.sub(r"\.mha\.dense", r".mha.output", name) - - # Pointwise feedforward - name = re.sub(r"\.mlp", r".ffn", name) - name = re.sub(r"\.dense_h_to_4h", r".intermediate", name) - name = re.sub(r"\.dense_4h_to_h", r".output", name) - - # Layer norms - name = re.sub(r"\.input_layernorm", r".attn_input_layer_norm", name) - name = re.sub(r"\.post_attention_layernorm", r".ffn_input_layer_norm", name) - name = re.sub(r"final_layer_norm\.", r"output_layer_norm.", name) - - # Embeddings - name = re.sub(r"embed_in\.", r"embeddings.piece_embeddings.", name) - name = re.sub(r"embed_out\.", r"output_embeddings.", name) - - out[name] = parameter - - return out diff --git a/curated_transformers/models/gpt_neox/causal_lm.py b/curated_transformers/models/gpt_neox/causal_lm.py index cc6f9bb7..eae67a2d 100644 --- a/curated_transformers/models/gpt_neox/causal_lm.py +++ b/curated_transformers/models/gpt_neox/causal_lm.py @@ -6,8 +6,9 @@ from ...quantization import Quantizable from ..hf_hub import FromHFHub +from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..transformer import TransformerCausalLM -from ._hf import convert_hf_config, convert_hf_state_dict +from ._hf import CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS, convert_hf_config from .config import GPTNeoXConfig from .decoder import GPTNeoXDecoder @@ -46,8 +47,16 @@ def __init__( ) @classmethod - def convert_hf_state_dict(cls, params: Mapping[str, Tensor]): - return convert_hf_state_dict(cls, params) + def state_dict_from_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return state_dict_from_hf(params, CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS) + + @classmethod + def state_dict_to_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return state_dict_to_hf(params, CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS) @classmethod def from_hf_config( diff --git a/curated_transformers/models/gpt_neox/decoder.py b/curated_transformers/models/gpt_neox/decoder.py index d04189ca..c7eb634a 100644 --- a/curated_transformers/models/gpt_neox/decoder.py +++ b/curated_transformers/models/gpt_neox/decoder.py @@ -17,8 +17,9 @@ TransformerLayerNorms, ) from ..hf_hub import FromHFHub +from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..transformer import TransformerDecoder -from ._hf import convert_hf_config, convert_hf_state_dict +from ._hf import DECODER_HF_PARAM_KEY_TRANSFORMS, convert_hf_config from .config import GPTNeoXConfig # Only provided as typing.Self in Python 3.11+. 
@@ -114,8 +115,16 @@ def __init__( ) @classmethod - def convert_hf_state_dict(cls, params: Mapping[str, Tensor]): - return convert_hf_state_dict(cls, params) + def state_dict_from_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return state_dict_from_hf(params, DECODER_HF_PARAM_KEY_TRANSFORMS) + + @classmethod + def state_dict_to_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return state_dict_to_hf(params, DECODER_HF_PARAM_KEY_TRANSFORMS) @classmethod def from_hf_config( diff --git a/curated_transformers/models/hf_hub/__init__.py b/curated_transformers/models/hf_hub/__init__.py new file mode 100644 index 00000000..d0c7bd0f --- /dev/null +++ b/curated_transformers/models/hf_hub/__init__.py @@ -0,0 +1 @@ +from .mixin import FromHFHub diff --git a/curated_transformers/models/hf_hub/conversion.py b/curated_transformers/models/hf_hub/conversion.py new file mode 100644 index 00000000..0b3258a2 --- /dev/null +++ b/curated_transformers/models/hf_hub/conversion.py @@ -0,0 +1,101 @@ +from typing import Any, Callable, Dict, List, Mapping, Tuple, Union + +from torch import Tensor + +from ...util.string import StringTransform, StringTransformations + + +def process_hf_keys( + model_name: str, + hf_config: Dict[str, Any], + hf_to_curated: Dict[str, Union[str, Tuple[str, Callable]]], + extra_keys: List[str] = [], +) -> Dict[str, Any]: + """ + Convert Hugging Face configuration keys to keyword arguments for + Curated Transformers configuration classes. + + :param model_name: + Model name. Only used in exception messages. + :param hf_config: + Hugging Face model configuration. + :param hf_to_curated: + Dictionary that maps Hugging Face configuration keys to keyword + arguments for a Curated Transformers configuration class. If a value + is a tuple, the first tuple element is the name of the keyword + argument class and the second tuple element is a conversion function. + :param extra_keys: + Optional keys for which the Hugging Face configuration key and the + keyword argument of the Curated Transformers configuration class is + the same. + :returns: + Dictionary with keyword arguments. + """ + missing_keys = tuple( + sorted(set(hf_to_curated.keys()).difference(set(hf_config.keys()))) + ) + if len(missing_keys) != 0: + raise ValueError( + f"Missing keys in Hugging Face {model_name} model config: {missing_keys}" + ) + + kwargs = {} + + for hf, curated in hf_to_curated.items(): + if isinstance(curated, tuple): + curated, ctor = curated + else: + ctor = lambda x: x + + kwargs[curated] = ctor(hf_config[hf]) + + # Handle config options that are not set in all models. + kwargs.update({k: hf_config[k] for k in extra_keys if k in hf_config}) + + return kwargs + + +def state_dict_from_hf( + params: Mapping[str, Tensor], transforms: List[StringTransform] +) -> Mapping[str, Tensor]: + """ + Apply transformations to a Hugging Face state dict to make it + compatible with Curated Transformer modules. + + :param params: + Hugging Face state dict. + :param transforms: + List of string transformations for the state dict's keys. + :returns: + Transformed state dict. + """ + out = {} + for key, param in params.items(): + for transform in transforms: + key = transform.apply(key) + out[key] = param + return out + + +def state_dict_to_hf( + params: Mapping[str, Tensor], transforms: List[StringTransform] +) -> Mapping[str, Tensor]: + """ + Apply transformations to a Curated Transformer state dict to make it + compatible with Hugging Face modules. 
+ + :param params: + Curated Transformer state dict. + :param transforms: + List of string transformations for the state dict's keys. + This must be the same transformations that were used to + convert the original Hugging Face state dict. + :returns: + Transformed state dict. + """ + out = {} + for key, param in params.items(): + for transform in transforms[::-1]: + key = transform.revert(key) + out[key] = param + return out diff --git a/curated_transformers/models/hf_hub.py b/curated_transformers/models/hf_hub/mixin.py similarity index 74% rename from curated_transformers/models/hf_hub.py rename to curated_transformers/models/hf_hub/mixin.py index 986c4c4c..b0807c8c 100644 --- a/curated_transformers/models/hf_hub.py +++ b/curated_transformers/models/hf_hub/mixin.py @@ -1,28 +1,17 @@ from abc import ABC, abstractmethod -from typing import ( - Any, - Callable, - Dict, - List, - Mapping, - Optional, - Tuple, - Type, - TypeVar, - Union, -) +from typing import Any, Mapping, Optional, Type, TypeVar import torch from fsspec import AbstractFileSystem from torch import Tensor -from ..quantization import prepare_module_for_quantization -from ..quantization.bnb.config import BitsAndBytesConfig -from ..repository.fsspec import FsspecArgs, FsspecRepository -from ..repository.hf_hub import HfHubRepository -from ..repository.repository import ModelRepository, Repository -from ..util.serde import load_model_from_checkpoints -from .module import TransformerModule +from ...quantization import prepare_module_for_quantization +from ...quantization.bnb.config import BitsAndBytesConfig +from ...repository.fsspec import FsspecArgs, FsspecRepository +from ...repository.hf_hub import HfHubRepository +from ...repository.repository import ModelRepository, Repository +from ...util.serde import load_model_from_checkpoints +from ..module import TransformerModule # Only provided as typing.Self in Python 3.11+. Self = TypeVar("Self", bound="FromHFHub") @@ -38,10 +27,17 @@ class FromHFHub(ABC): """ @classmethod - @abstractmethod def convert_hf_state_dict( cls, params: Mapping[str, Tensor] ) -> Mapping[str, Tensor]: + """ + Alias for :meth:`.state_dict_from_hf`. + """ + return cls.state_dict_from_hf(params) + + @classmethod + @abstractmethod + def state_dict_from_hf(cls, params: Mapping[str, Tensor]) -> Mapping[str, Tensor]: """ Convert a state dict of a Hugging Face model to a valid state dict for the module. @@ -51,7 +47,21 @@ def convert_hf_state_dict( :returns: The converted state dict. """ - raise NotImplementedError + ... + + @classmethod + @abstractmethod + def state_dict_to_hf(cls, params: Mapping[str, Tensor]) -> Mapping[str, Tensor]: + """ + Convert the state dict of the module to a compatible + Hugging Face model's format. + + :param params: + The state dict to convert. + :returns: + The converted state dict. + """ + ... @classmethod @abstractmethod @@ -72,7 +82,7 @@ def from_hf_config( :returns: Module constructed using the configuration. """ - raise NotImplementedError + ... @classmethod def from_hf_hub_to_cache( @@ -232,53 +242,3 @@ def from_repo( model.to(device) return model - - -def _process_hf_keys( - model_name: str, - hf_config: Dict[str, Any], - hf_to_curated: Dict[str, Union[str, Tuple[str, Callable]]], - extra_keys: List[str] = [], -) -> Dict[str, Any]: - """ - Convert Hugging Face configuration keys to keyword arguments for - Curated Transformers configuration classes. - - :param model_name: - Model name. Only used in exception messages. - :param hf_config: - Hugging Face model configuration. 
- :param hf_to_curated: - Dictionay that maps Hugging Face configuration keys to keyword - arguments for a Curated Transformers configuration class. If a value - is a tuple, the first tuple element is the name of the keyword - argument class and the second tuple element is a conversion function. - :param extra_keys: - Optional keys for which the Hugging Face configuration key and the - keyword argument of the Curated Transformers configuration class is - the same. - :returns: - Dictionary with keyword arguments. - """ - missing_keys = tuple( - sorted(set(hf_to_curated.keys()).difference(set(hf_config.keys()))) - ) - if len(missing_keys) != 0: - raise ValueError( - f"Missing keys in Hugging Face {model_name} model config: {missing_keys}" - ) - - kwargs = {} - - for hf, curated in hf_to_curated.items(): - if isinstance(curated, tuple): - curated, ctor = curated - else: - ctor = lambda x: x - - kwargs[curated] = ctor(hf_config[hf]) - - # Handle config options that are not set in all models. - kwargs.update({k: hf_config[k] for k in extra_keys if k in hf_config}) - - return kwargs diff --git a/curated_transformers/models/llama/_hf.py b/curated_transformers/models/llama/_hf.py index 891beed5..b3cace31 100644 --- a/curated_transformers/models/llama/_hf.py +++ b/curated_transformers/models/llama/_hf.py @@ -1,17 +1,45 @@ -import re -from typing import Any, Callable, Dict, Mapping, Tuple, Union - -from torch import Tensor +from typing import Any, Callable, Dict, List, Tuple, Union from ...layers.activations import Activation -from ..hf_hub import _process_hf_keys -from ..module import DecoderModule +from ...util.string import StringTransform, StringTransformations +from ..hf_hub.conversion import process_hf_keys from .config import LlamaConfig ATTENTION_DROPOUT = "attention_probs_dropout_prob" HIDDEN_DROPOUT = "hidden_dropout_prob" EXTRA_KWARG_KEYS = [ATTENTION_DROPOUT, HIDDEN_DROPOUT] +# Order-dependent. +COMMON_HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [ + # Attention blocks. 
+ StringTransformations.sub(".self_attn", ".mha"), + StringTransformations.sub(".q_proj", ".query"), + StringTransformations.sub(".k_proj", ".key"), + StringTransformations.sub(".v_proj", ".value"), + StringTransformations.sub(".o_proj", ".output"), + # Pointwise feedforward + StringTransformations.sub(".mlp", ".ffn"), + StringTransformations.sub(".up_proj", ".intermediate"), + StringTransformations.sub("ffn.down_proj", "ffn.output"), + StringTransformations.sub(".gate_proj", ".gate"), + # RMS norms + StringTransformations.sub(".input_layernorm", ".attn_input_layer_norm"), + StringTransformations.sub(".post_attention_layernorm", ".ffn_input_layer_norm"), + StringTransformations.regex_sub( + (r"^(decoder\.)?norm\.", "\\1output_layer_norm."), + (r"^(decoder\.)?output_layer_norm\.", "\\1norm."), + ), + # Embeddings + StringTransformations.sub("embed_tokens.", "embeddings.piece_embeddings."), + StringTransformations.sub("lm_head.", "output_embeddings."), +] + +DECODER_HF_PARAM_KEY_TRANSFORMS = [ + StringTransformations.remove_prefix("model.", reversible=False) +] + COMMON_HF_PARAM_KEY_TRANSFORMS +CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS = [ + StringTransformations.sub("model.", "decoder.") +] + COMMON_HF_PARAM_KEY_TRANSFORMS HF_CONFIG_KEY_MAPPING: Dict[str, Union[str, Tuple[str, Callable]]] = { "hidden_act": ("activation", Activation), @@ -25,7 +53,7 @@ def convert_hf_config(hf_config: Any) -> LlamaConfig: - kwargs = _process_hf_keys( + kwargs = process_hf_keys( "Llama", hf_config, HF_CONFIG_KEY_MAPPING, EXTRA_KWARG_KEYS ) @@ -37,43 +65,3 @@ def convert_hf_config(hf_config: Any) -> LlamaConfig: rotary_embedding_fraction=1.0, **kwargs, ) - - -def convert_hf_state_dict(cls, params: Mapping[str, Tensor]) -> Mapping[str, Tensor]: - """Convert state dict from HF paramater naming to ours. 
- The function is insensitive to prefixes, to allow loading - both the decoder and the full LM.""" - if issubclass(cls, DecoderModule): - stripped_params = {re.sub(r"^model\.", "", k): v for k, v in params.items()} - else: - stripped_params = { - re.sub(r"^model\.", "decoder.", k): v for k, v in params.items() - } - - out = {} - for name, parameter in stripped_params.items(): - # Attention - name = re.sub(r"\.self_attn", r".mha", name) - name = re.sub(r"\.q_proj", r".query", name) - name = re.sub(r"\.k_proj", r".key", name) - name = re.sub(r"\.v_proj", r".value", name) - name = re.sub(r"\.o_proj", r".output", name) - - # Pointwise feedforward - name = re.sub(r"\.mlp", r".ffn", name) - name = re.sub(r"\.up_proj", r".intermediate", name) - name = re.sub(r"\.down_proj", r".output", name) - name = re.sub(r"\.gate_proj", r".gate", name) - - # RMS norms - name = re.sub(r"\.input_layernorm", r".attn_input_layer_norm", name) - name = re.sub(r"\.post_attention_layernorm", r".ffn_input_layer_norm", name) - name = re.sub(r"^(decoder\.)?norm\.", r"\1output_layer_norm.", name) - - # Embeddings - name = re.sub(r"embed_tokens\.", r"embeddings.piece_embeddings.", name) - name = re.sub(r"lm_head\.", r"output_embeddings.", name) - - out[name] = parameter - - return out diff --git a/curated_transformers/models/llama/causal_lm.py b/curated_transformers/models/llama/causal_lm.py index 135178cc..aa9f5320 100644 --- a/curated_transformers/models/llama/causal_lm.py +++ b/curated_transformers/models/llama/causal_lm.py @@ -6,8 +6,9 @@ from ...quantization import Quantizable from ..hf_hub import FromHFHub +from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..transformer import TransformerCausalLM -from ._hf import convert_hf_config, convert_hf_state_dict +from ._hf import CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS, convert_hf_config from .config import LlamaConfig from .decoder import LlamaDecoder @@ -47,8 +48,16 @@ def __init__( ) @classmethod - def convert_hf_state_dict(cls, params: Mapping[str, Tensor]): - return convert_hf_state_dict(cls, params) + def state_dict_from_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return state_dict_from_hf(params, CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS) + + @classmethod + def state_dict_to_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return state_dict_to_hf(params, CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS) @classmethod def from_hf_config( diff --git a/curated_transformers/models/llama/decoder.py b/curated_transformers/models/llama/decoder.py index 072a2cc3..1e4b7e69 100644 --- a/curated_transformers/models/llama/decoder.py +++ b/curated_transformers/models/llama/decoder.py @@ -18,8 +18,9 @@ TransformerLayerNorms, ) from ..hf_hub import FromHFHub +from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..transformer import TransformerDecoder -from ._hf import convert_hf_config, convert_hf_state_dict +from ._hf import DECODER_HF_PARAM_KEY_TRANSFORMS, convert_hf_config from .config import LlamaConfig # Only provided as typing.Self in Python 3.11+. 
@@ -120,8 +121,16 @@ def __init__( ) @classmethod - def convert_hf_state_dict(cls, params: Mapping[str, Tensor]): - return convert_hf_state_dict(cls, params) + def state_dict_from_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return state_dict_from_hf(params, DECODER_HF_PARAM_KEY_TRANSFORMS) + + @classmethod + def state_dict_to_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return state_dict_to_hf(params, DECODER_HF_PARAM_KEY_TRANSFORMS) @classmethod def from_hf_config( diff --git a/curated_transformers/models/mpt/_hf.py b/curated_transformers/models/mpt/_hf.py index 36de5bd1..e0c1cf37 100644 --- a/curated_transformers/models/mpt/_hf.py +++ b/curated_transformers/models/mpt/_hf.py @@ -1,16 +1,38 @@ -import re -from typing import Any, Callable, Dict, Mapping, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Union -from torch import Tensor, dropout, layer_norm - -from ..hf_hub import _process_hf_keys -from ..module import DecoderModule +from ...util.string import StringTransform, StringTransformations +from ..hf_hub.conversion import process_hf_keys from .config import MPTConfig ATTENTION_DROPOUT = "attention_probs_dropout_prob" HIDDEN_DROPOUT = "hidden_dropout_prob" EXTRA_KWARG_KEYS = [ATTENTION_DROPOUT, HIDDEN_DROPOUT] +# Order-dependent. +COMMON_HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [ + StringTransformations.sub("transformer", "decoder"), + StringTransformations.sub("blocks", "layers"), + # Attention blocks. + StringTransformations.sub(".attn", ".mha"), + StringTransformations.sub(".Wqkv", ".input"), + StringTransformations.sub(".out_proj", ".output"), + # Pointwise feedforward. + StringTransformations.sub(".up_proj", ".intermediate"), + StringTransformations.sub("ffn.down_proj", "ffn.output"), + # Layer norms. + StringTransformations.sub(".norm_1", ".attn_input_layer_norm"), + StringTransformations.sub(".norm_2", ".ffn_input_layer_norm"), + StringTransformations.sub("norm_f.", "output_layer_norm."), + # Embeddings. + StringTransformations.sub("wte.", "embeddings.piece_embeddings."), +] + + +DECODER_HF_PARAM_KEY_TRANSFORMS = [ + StringTransformations.remove_prefix("transformer.", reversible=False) +] + COMMON_HF_PARAM_KEY_TRANSFORMS +CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS = COMMON_HF_PARAM_KEY_TRANSFORMS + HF_CONFIG_KEY_MAPPING: Dict[str, Union[str, Tuple[str, Callable]]] = { "d_model": "hidden_width", "expansion_ratio": "intermediate_width_multiplier", @@ -22,7 +44,7 @@ def convert_hf_config(hf_config: Any) -> MPTConfig: - kwargs = _process_hf_keys("MPT", hf_config, HF_CONFIG_KEY_MAPPING, EXTRA_KWARG_KEYS) + kwargs = process_hf_keys("MPT", hf_config, HF_CONFIG_KEY_MAPPING, EXTRA_KWARG_KEYS) no_bias = hf_config.get("no_bias") if no_bias is None: @@ -42,49 +64,3 @@ def convert_hf_config(hf_config: Any) -> MPTConfig: layer_norm_eps=layer_norm_eps, use_bias=not no_bias, ) - - -def convert_hf_state_dict(cls, params: Mapping[str, Tensor]) -> Mapping[str, Tensor]: - """Convert state dict from HF paramater naming to ours. - The function is insensitive to prefixes, to allow loading - both the decoder and the full LM.""" - if issubclass(cls, DecoderModule): - stripped_params = { - re.sub(r"^transformer\.", "", k): v - for k, v in params.items() - # The decoder does not the output embeddings, avoid unexpected key. - if k != "lm_head.weight" - } - else: - # Rewrap as dict if necessay to make MyPy happy. 
- stripped_params = dict(params) - - out = {} - for name, parameter in stripped_params.items(): - # Input and output embeddings are tied in MPT. - if "lm_head" in name: - continue - - name = name.replace("transformer", "decoder") - name = name.replace("blocks", "layers") - - # Attention - name = re.sub(r"\.attn", r".mha", name) - name = re.sub(r"\.Wqkv", r".input", name) - name = re.sub(r"\.out_proj", r".output", name) - - # Pointwise feedforward - name = re.sub(r"\.up_proj", r".intermediate", name) - name = re.sub(r"\.down_proj", r".output", name) - - # Layer norms - name = re.sub(r"\.norm_1", r".attn_input_layer_norm", name) - name = re.sub(r"\.norm_2", r".ffn_input_layer_norm", name) - name = re.sub(r"norm_f\.", r"output_layer_norm.", name) - - # Embeddings - name = re.sub(r"wte\.", r"embeddings.piece_embeddings.", name) - - out[name] = parameter - - return out diff --git a/curated_transformers/models/mpt/causal_lm.py b/curated_transformers/models/mpt/causal_lm.py index 0006c064..dfc9c72a 100644 --- a/curated_transformers/models/mpt/causal_lm.py +++ b/curated_transformers/models/mpt/causal_lm.py @@ -9,9 +9,10 @@ from ...layers.cache import KeyValueCache from ...quantization import Quantizable from ..hf_hub import FromHFHub +from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..output import CausalLMOutputWithCache from ..transformer import TransformerCausalLM -from ._hf import convert_hf_config, convert_hf_state_dict +from ._hf import CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS, convert_hf_config from .config import MPTConfig from .decoder import MPTDecoder @@ -84,8 +85,16 @@ def forward( ) @classmethod - def convert_hf_state_dict(cls, params: Mapping[str, Tensor]): - return convert_hf_state_dict(cls, params) + def state_dict_from_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return state_dict_from_hf(params, CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS) + + @classmethod + def state_dict_to_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return state_dict_to_hf(params, CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS) @classmethod def from_hf_config( diff --git a/curated_transformers/models/mpt/decoder.py b/curated_transformers/models/mpt/decoder.py index 1207dc5c..f6acc05d 100644 --- a/curated_transformers/models/mpt/decoder.py +++ b/curated_transformers/models/mpt/decoder.py @@ -21,8 +21,9 @@ TransformerLayerNorms, ) from ..hf_hub import FromHFHub +from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..transformer import TransformerDecoder -from ._hf import convert_hf_config, convert_hf_state_dict +from ._hf import DECODER_HF_PARAM_KEY_TRANSFORMS, convert_hf_config from .config import MPTConfig # Only provided as typing.Self in Python 3.11+. 
@@ -120,8 +121,16 @@ def layer_norm(): self.output_layer_norm = layer_norm() @classmethod - def convert_hf_state_dict(cls, params: Mapping[str, Tensor]): - return convert_hf_state_dict(cls, params) + def state_dict_from_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return state_dict_from_hf(params, DECODER_HF_PARAM_KEY_TRANSFORMS) + + @classmethod + def state_dict_to_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return state_dict_to_hf(params, DECODER_HF_PARAM_KEY_TRANSFORMS) @classmethod def from_hf_config( diff --git a/curated_transformers/models/roberta/_hf.py b/curated_transformers/models/roberta/_hf.py index eed40f25..29631493 100644 --- a/curated_transformers/models/roberta/_hf.py +++ b/curated_transformers/models/roberta/_hf.py @@ -1,23 +1,56 @@ -import re -from types import MappingProxyType -from typing import Any, Callable, Dict, Mapping, Tuple, Union - -from torch import Tensor +from typing import Any, Callable, Dict, List, Tuple, Union from ...layers.activations import Activation -from ..hf_hub import _process_hf_keys +from ...util.string import StringTransform, StringTransformations +from ..hf_hub.conversion import process_hf_keys from .config import RoBERTaConfig -HF_KEY_TO_CURATED_KEY = MappingProxyType( - { - "embeddings.word_embeddings.weight": "embeddings.piece_embeddings.weight", - "embeddings.token_type_embeddings.weight": "embeddings.type_embeddings.weight", - "embeddings.position_embeddings.weight": "embeddings.position_embeddings.weight", - "embeddings.LayerNorm.weight": "embeddings.embed_output_layer_norm.weight", - "embeddings.LayerNorm.bias": "embeddings.embed_output_layer_norm.bias", - } -) - +# Order-dependent. +HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [ + # Prefixes. + StringTransformations.remove_prefix("roberta.", reversible=False), + StringTransformations.regex_sub( + (r"^encoder\.(layer\.)", "\\1"), + (r"^(layer\.)", "encoder.\\1"), + ), + # Layers. + StringTransformations.regex_sub((r"^layer", "layers"), (r"^layers", "layer")), + # Attention blocks. + StringTransformations.regex_sub( + (r"\.attention\.self\.(query|key|value)", ".mha.\\1"), + (r"\.mha\.(query|key|value)", ".attention.self.\\1"), + ), + StringTransformations.sub(".attention.output.dense", ".mha.output"), + StringTransformations.sub( + r".attention.output.LayerNorm", ".attn_residual_layer_norm" + ), + # Pointwise feed-forward layers. + StringTransformations.sub(".intermediate.dense", ".ffn.intermediate"), + StringTransformations.regex_sub( + (r"(\.\d+)\.output\.LayerNorm", "\\1.ffn_residual_layer_norm"), + (r"(\.\d+)\.ffn_residual_layer_norm", "\\1.output.LayerNorm"), + ), + StringTransformations.regex_sub( + (r"(\.\d+)\.output\.dense", "\\1.ffn.output"), + (r"(\.\d+)\.ffn\.output", "\\1.output.dense"), + ), + # Embeddings. 
+ StringTransformations.replace( + "embeddings.word_embeddings.weight", "embeddings.piece_embeddings.weight" + ), + StringTransformations.replace( + "embeddings.token_type_embeddings.weight", "embeddings.type_embeddings.weight" + ), + StringTransformations.replace( + "embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight" + ), + StringTransformations.replace( + "embeddings.LayerNorm.weight", "embeddings.embed_output_layer_norm.weight" + ), + StringTransformations.replace( + "embeddings.LayerNorm.bias", "embeddings.embed_output_layer_norm.bias" + ), +] HF_CONFIG_KEY_MAPPING: Dict[str, Union[str, Tuple[str, Callable]]] = { "pad_token_id": "padding_id", @@ -36,7 +69,7 @@ def convert_hf_config(hf_config: Any) -> RoBERTaConfig: - kwargs = _process_hf_keys("RoBERTa", hf_config, HF_CONFIG_KEY_MAPPING) + kwargs = process_hf_keys("RoBERTa", hf_config, HF_CONFIG_KEY_MAPPING) return RoBERTaConfig( embedding_width=hf_config["hidden_size"], @@ -45,41 +78,3 @@ def convert_hf_config(hf_config: Any) -> RoBERTaConfig: - (kwargs["padding_id"] + 1), **kwargs, ) - - -def convert_hf_state_dict(params: Mapping[str, Tensor]) -> Mapping[str, Tensor]: - out = {} - - # Strip the `roberta` prefix from XLM-RoBERTa model parameters. - stripped_params = {re.sub(r"^roberta\.", "", k): v for k, v in params.items()} - - for name, parameter in stripped_params.items(): - if "encoder.layer." not in name: - continue - - # TODO: Make these substitutions less ugly. - - # Remove the prefix and rename the internal 'layers' variable. - name = re.sub(r"^encoder\.", "", name) - name = re.sub(r"^layer", "layers", name) - - # The HF model has one more level of indirection for the output layers in their - # attention heads and the feed-forward network layers. - name = re.sub(r"\.attention\.self\.(query|key|value)", r".mha.\1", name) - name = re.sub(r"\.attention\.(output)\.dense", r".mha.\1", name) - name = re.sub( - r"\.attention\.output\.LayerNorm", r".attn_residual_layer_norm", name - ) - name = re.sub(r"\.(intermediate)\.dense", r".ffn.\1", name) - name = re.sub( - r"(\.\d+)\.output\.LayerNorm", r"\1.ffn_residual_layer_norm", name - ) - name = re.sub(r"(\.\d+)\.(output)\.dense", r"\1.ffn.\2", name) - - out[name] = parameter - - for hf_name, curated_name in HF_KEY_TO_CURATED_KEY.items(): - if hf_name in stripped_params: - out[curated_name] = stripped_params[hf_name] - - return out diff --git a/curated_transformers/models/roberta/encoder.py b/curated_transformers/models/roberta/encoder.py index 11b06c1d..b9cfc3f3 100644 --- a/curated_transformers/models/roberta/encoder.py +++ b/curated_transformers/models/roberta/encoder.py @@ -15,8 +15,9 @@ TransformerLayerNorms, ) from ..hf_hub import FromHFHub +from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..transformer import TransformerEncoder -from ._hf import convert_hf_config, convert_hf_state_dict +from ._hf import HF_PARAM_KEY_TRANSFORMS, convert_hf_config from .config import RoBERTaConfig from .embeddings import RoBERTaEmbeddings @@ -105,8 +106,16 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No ) @classmethod - def convert_hf_state_dict(cls, params: Mapping[str, Tensor]): - return convert_hf_state_dict(params) + def state_dict_from_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return state_dict_from_hf(params, HF_PARAM_KEY_TRANSFORMS) + + @classmethod + def state_dict_to_hf( + cls: Type[Self], params: Mapping[str, Tensor] + ) -> Mapping[str, Tensor]: + return 
state_dict_to_hf(params, HF_PARAM_KEY_TRANSFORMS) @classmethod def from_hf_config( diff --git a/curated_transformers/tests/models/albert/test_encoder.py b/curated_transformers/tests/models/albert/test_encoder.py index da88c504..6be65b27 100644 --- a/curated_transformers/tests/models/albert/test_encoder.py +++ b/curated_transformers/tests/models/albert/test_encoder.py @@ -4,7 +4,11 @@ from ...compat import has_hf_transformers, has_torch_compile from ...conftest import TORCH_DEVICES -from ..util import JITMethod, assert_encoder_output_equals_hf +from ..util import ( + JITMethod, + assert_encoder_output_equals_hf, + assert_model_hf_serialization_roundtrip, +) def test_rejects_incorrect_number_of_groups(): @@ -52,3 +56,11 @@ def test_encoder_with_torchscript_trace(torch_device, with_torch_sdp): jit_method=JITMethod.TorchScriptTrace, with_torch_sdp=with_torch_sdp, ) + + +@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") +@pytest.mark.parametrize("torch_device", TORCH_DEVICES) +def test_encoder_hf_serializtion_roundtrip(torch_device): + assert_model_hf_serialization_roundtrip( + ALBERTEncoder, "explosion-testing/albert-test", torch_device + ) diff --git a/curated_transformers/tests/models/bert/test_encoder.py b/curated_transformers/tests/models/bert/test_encoder.py index 61083070..3736ab09 100644 --- a/curated_transformers/tests/models/bert/test_encoder.py +++ b/curated_transformers/tests/models/bert/test_encoder.py @@ -4,7 +4,11 @@ from ...compat import has_hf_transformers, has_torch_compile from ...conftest import TORCH_DEVICES -from ..util import JITMethod, assert_encoder_output_equals_hf +from ..util import ( + JITMethod, + assert_encoder_output_equals_hf, + assert_model_hf_serialization_roundtrip, +) @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") @@ -46,3 +50,11 @@ def test_encoder_with_torchscript_trace(torch_device, with_torch_sdp): jit_method=JITMethod.TorchScriptTrace, with_torch_sdp=with_torch_sdp, ) + + +@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") +@pytest.mark.parametrize("torch_device", TORCH_DEVICES) +def test_encoder_hf_serializtion_roundtrip(torch_device): + assert_model_hf_serialization_roundtrip( + BERTEncoder, "explosion-testing/bert-test", torch_device + ) diff --git a/curated_transformers/tests/models/camembert/test_encoder.py b/curated_transformers/tests/models/camembert/test_encoder.py index b9c10ea8..0eeb6e73 100644 --- a/curated_transformers/tests/models/camembert/test_encoder.py +++ b/curated_transformers/tests/models/camembert/test_encoder.py @@ -4,7 +4,11 @@ from ...compat import has_hf_transformers, has_torch_compile from ...conftest import TORCH_DEVICES -from ..util import JITMethod, assert_encoder_output_equals_hf +from ..util import ( + JITMethod, + assert_encoder_output_equals_hf, + assert_model_hf_serialization_roundtrip, +) @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") @@ -46,3 +50,11 @@ def test_encoder_with_torchscript_trace(torch_device, with_torch_sdp): jit_method=JITMethod.TorchScriptTrace, with_torch_sdp=with_torch_sdp, ) + + +@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") +@pytest.mark.parametrize("torch_device", TORCH_DEVICES) +def test_encoder_hf_serializtion_roundtrip(torch_device): + assert_model_hf_serialization_roundtrip( + CamemBERTEncoder, "explosion-testing/camembert-test", torch_device + ) diff --git 
a/curated_transformers/tests/models/falcon/test_decoder.py b/curated_transformers/tests/models/falcon/test_decoder.py
index a4ee9b7f..1cc03263 100644
--- a/curated_transformers/tests/models/falcon/test_decoder.py
+++ b/curated_transformers/tests/models/falcon/test_decoder.py
@@ -7,7 +7,11 @@
 from ...compat import has_hf_transformers, has_torch_compile
 from ...conftest import TORCH_DEVICES
 from ...utils import torch_assertclose
-from ..util import JITMethod, assert_decoder_output_equals_hf
+from ..util import (
+    JITMethod,
+    assert_decoder_output_equals_hf,
+    assert_model_hf_serialization_roundtrip,
+)

 N_PIECES = 1024

@@ -142,3 +146,17 @@ def test_decoder_with_cache(torch_device, model_revision):
     ).last_hidden_layer_state

     torch_assertclose(Y, Y_no_cache[:, 10:, :])
+
+
+@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
+@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
+@pytest.mark.parametrize("model_revision", FALCON_TEST_MODELS)
+def test_decoder_hf_serialization_roundtrip(torch_device, model_revision):
+    model, revision = model_revision
+    assert_model_hf_serialization_roundtrip(
+        FalconDecoder,
+        model,
+        torch_device,
+        model_revision=revision,
+        trust_remote_code=True,
+    )
diff --git a/curated_transformers/tests/models/gpt_neox/test_causal_lm.py b/curated_transformers/tests/models/gpt_neox/test_causal_lm.py
index 6630094c..6c8a47e2 100644
--- a/curated_transformers/tests/models/gpt_neox/test_causal_lm.py
+++ b/curated_transformers/tests/models/gpt_neox/test_causal_lm.py
@@ -4,7 +4,11 @@

 from ...compat import has_hf_transformers, has_torch_compile
 from ...conftest import TORCH_DEVICES
-from ..util import JITMethod, assert_causal_lm_output_equals_hf
+from ..util import (
+    JITMethod,
+    assert_causal_lm_output_equals_hf,
+    assert_model_hf_serialization_roundtrip,
+)


 @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@@ -46,3 +50,13 @@ def test_causal_lm_with_torchscript_trace(torch_device, with_torch_sdp):
         jit_method=JITMethod.TorchScriptTrace,
         with_torch_sdp=with_torch_sdp,
     )
+
+
+@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
+@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
+def test_causal_lm_hf_serialization_roundtrip(torch_device):
+    assert_model_hf_serialization_roundtrip(
+        GPTNeoXCausalLM,
+        "trl-internal-testing/tiny-random-GPTNeoXForCausalLM",
+        torch_device,
+    )
diff --git a/curated_transformers/tests/models/gpt_neox/test_decoder.py b/curated_transformers/tests/models/gpt_neox/test_decoder.py
index 6a8ad351..81dee8a1 100644
--- a/curated_transformers/tests/models/gpt_neox/test_decoder.py
+++ b/curated_transformers/tests/models/gpt_neox/test_decoder.py
@@ -4,7 +4,11 @@

 from ...compat import has_hf_transformers, has_torch_compile
 from ...conftest import TORCH_DEVICES
-from ..util import JITMethod, assert_decoder_output_equals_hf
+from ..util import (
+    JITMethod,
+    assert_decoder_output_equals_hf,
+    assert_model_hf_serialization_roundtrip,
+)


 @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@@ -54,3 +58,13 @@ def test_decoder_with_torchscript_trace(torch_device, with_torch_sdp):
         jit_method=JITMethod.TorchScriptTrace,
         with_torch_sdp=with_torch_sdp,
     )
+
+
+@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
+@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
+def test_decoder_hf_serialization_roundtrip(torch_device):
+    assert_model_hf_serialization_roundtrip(
+
GPTNeoXDecoder,
+        "trl-internal-testing/tiny-random-GPTNeoXForCausalLM",
+        torch_device,
+    )
diff --git a/curated_transformers/tests/models/llama/test_causal_lm.py b/curated_transformers/tests/models/llama/test_causal_lm.py
index 4efbc26d..54f8eb56 100644
--- a/curated_transformers/tests/models/llama/test_causal_lm.py
+++ b/curated_transformers/tests/models/llama/test_causal_lm.py
@@ -4,7 +4,11 @@

 from ...compat import has_hf_transformers, has_torch_compile
 from ...conftest import TORCH_DEVICES
-from ..util import JITMethod, assert_causal_lm_output_equals_hf
+from ..util import (
+    JITMethod,
+    assert_causal_lm_output_equals_hf,
+    assert_model_hf_serialization_roundtrip,
+)

 LLAMA_TEST_MODELS = [
     "trl-internal-testing/tiny-random-LlamaForCausalLM",
@@ -55,3 +59,10 @@ def test_causal_lm_with_torchscript_trace(torch_device, model, with_torch_sdp):
         jit_method=JITMethod.TorchScriptTrace,
         with_torch_sdp=with_torch_sdp,
     )
+
+
+@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
+@pytest.mark.parametrize("model", LLAMA_TEST_MODELS)
+@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
+def test_causal_lm_hf_serialization_roundtrip(model, torch_device):
+    assert_model_hf_serialization_roundtrip(LlamaCausalLM, model, torch_device)
diff --git a/curated_transformers/tests/models/llama/test_decoder.py b/curated_transformers/tests/models/llama/test_decoder.py
index 8cffc81b..105bc3f4 100644
--- a/curated_transformers/tests/models/llama/test_decoder.py
+++ b/curated_transformers/tests/models/llama/test_decoder.py
@@ -4,7 +4,11 @@

 from ...compat import has_hf_transformers, has_torch_compile
 from ...conftest import TORCH_DEVICES
-from ..util import JITMethod, assert_decoder_output_equals_hf
+from ..util import (
+    JITMethod,
+    assert_decoder_output_equals_hf,
+    assert_model_hf_serialization_roundtrip,
+)

 LLAMA_TEST_MODELS = [
     "trl-internal-testing/tiny-random-LlamaForCausalLM",
@@ -52,3 +56,10 @@ def test_decoder_with_torchscript_trace(torch_device, model, with_torch_sdp):
         jit_method=JITMethod.TorchScriptTrace,
         with_torch_sdp=with_torch_sdp,
     )
+
+
+@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
+@pytest.mark.parametrize("model", LLAMA_TEST_MODELS)
+@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
+def test_decoder_hf_serialization_roundtrip(model, torch_device):
+    assert_model_hf_serialization_roundtrip(LlamaDecoder, model, torch_device)
diff --git a/curated_transformers/tests/models/mpt/test_causal_lm.py b/curated_transformers/tests/models/mpt/test_causal_lm.py
index 5c2b1e1d..b3c41bd8 100644
--- a/curated_transformers/tests/models/mpt/test_causal_lm.py
+++ b/curated_transformers/tests/models/mpt/test_causal_lm.py
@@ -4,7 +4,11 @@

 from ...compat import has_hf_transformers, has_torch_compile
 from ...conftest import TORCH_DEVICES
-from ..util import JITMethod, assert_causal_lm_output_equals_hf
+from ..util import (
+    JITMethod,
+    assert_causal_lm_output_equals_hf,
+    assert_model_hf_serialization_roundtrip,
+)


 @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@@ -46,3 +50,11 @@ def test_causal_lm_with_torchscript_trace(torch_device, with_torch_sdp):
         jit_method=JITMethod.TorchScriptTrace,
         with_torch_sdp=with_torch_sdp,
     )
+
+
+@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
+@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
+def test_causal_lm_hf_serialization_roundtrip(torch_device):
+    assert_model_hf_serialization_roundtrip(
+        MPTCausalLM,
"explosion-testing/mpt-test", torch_device + ) diff --git a/curated_transformers/tests/models/mpt/test_decoder.py b/curated_transformers/tests/models/mpt/test_decoder.py index 627cfe6e..69ffc7f5 100644 --- a/curated_transformers/tests/models/mpt/test_decoder.py +++ b/curated_transformers/tests/models/mpt/test_decoder.py @@ -4,7 +4,11 @@ from ...compat import has_hf_transformers, has_torch_compile from ...conftest import TORCH_DEVICES -from ..util import JITMethod, assert_decoder_output_equals_hf +from ..util import ( + JITMethod, + assert_decoder_output_equals_hf, + assert_model_hf_serialization_roundtrip, +) @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") @@ -56,3 +60,11 @@ def test_decoder_with_torchscript_trace(torch_device, with_torch_sdp): jit_method=JITMethod.TorchScriptTrace, with_torch_sdp=with_torch_sdp, ) + + +@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") +@pytest.mark.parametrize("torch_device", TORCH_DEVICES) +def test_decoder_hf_serializtion_roundtrip(torch_device): + assert_model_hf_serialization_roundtrip( + MPTDecoder, "explosion-testing/mpt-test", torch_device + ) diff --git a/curated_transformers/tests/models/roberta/test_encoder.py b/curated_transformers/tests/models/roberta/test_encoder.py index dfa6b1d4..73739fb1 100644 --- a/curated_transformers/tests/models/roberta/test_encoder.py +++ b/curated_transformers/tests/models/roberta/test_encoder.py @@ -4,7 +4,11 @@ from ...compat import has_hf_transformers, has_torch_compile from ...conftest import TORCH_DEVICES -from ..util import JITMethod, assert_encoder_output_equals_hf +from ..util import ( + JITMethod, + assert_encoder_output_equals_hf, + assert_model_hf_serialization_roundtrip, +) @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") @@ -46,3 +50,11 @@ def test_encoder_with_torchscript_trace(torch_device, with_torch_sdp): jit_method=JITMethod.TorchScriptTrace, with_torch_sdp=with_torch_sdp, ) + + +@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") +@pytest.mark.parametrize("torch_device", TORCH_DEVICES) +def test_encoder_hf_serializtion_roundtrip(torch_device): + assert_model_hf_serialization_roundtrip( + RoBERTaEncoder, "explosion-testing/roberta-test", torch_device + ) diff --git a/curated_transformers/tests/models/util.py b/curated_transformers/tests/models/util.py index d9a12926..3f7913bf 100644 --- a/curated_transformers/tests/models/util.py +++ b/curated_transformers/tests/models/util.py @@ -364,3 +364,44 @@ def assert_model_config(model: TransformerModule, model_output: Tensor): hidden_width = model_output.size(-1) assert config.layer.feedforward.hidden_width == hidden_width + + +def assert_model_hf_serialization_roundtrip( + model_class: Type[FromHFHub], + model_name: str, + torch_device: torch.device, + *, + model_revision: str = "main", + atol: float = 1e-5, + rtol: float = 1e-5, + trust_remote_code: bool = False, +): + orig_model = model_class.from_hf_hub( + name=model_name, + revision=model_revision, + device=torch_device, + ) + orig_model.eval() + + for _, param in orig_model.state_dict().items(): + assert param.device == torch_device + + auto_cls = ( + transformers.AutoModelForCausalLM + if isinstance(orig_model, CausalLMModule) + else transformers.AutoModel + ) + + hf_model = auto_cls.from_pretrained( + model_name, + revision=model_revision, + trust_remote_code=trust_remote_code, + ) + hf_model.to(torch_device) + hf_model.eval() + + 
hf_model_statedict = hf_model.state_dict()
+    orig_model_hf_statedict = orig_model.state_dict_to_hf(orig_model.state_dict())
+    for name in orig_model_hf_statedict.keys():
+        assert name in hf_model_statedict.keys(), f"{name} not found in HF state dict"
+        torch_assertclose(orig_model_hf_statedict[name], hf_model_statedict[name])
diff --git a/curated_transformers/tests/models/xlm_roberta/test_encoder.py b/curated_transformers/tests/models/xlm_roberta/test_encoder.py
index 3c4ef124..1fc3773e 100644
--- a/curated_transformers/tests/models/xlm_roberta/test_encoder.py
+++ b/curated_transformers/tests/models/xlm_roberta/test_encoder.py
@@ -4,7 +4,11 @@

 from ...compat import has_hf_transformers, has_torch_compile
 from ...conftest import TORCH_DEVICES
-from ..util import JITMethod, assert_encoder_output_equals_hf
+from ..util import (
+    JITMethod,
+    assert_encoder_output_equals_hf,
+    assert_model_hf_serialization_roundtrip,
+)


 @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@@ -46,3 +50,11 @@ def test_encoder_with_torchscript_trace(torch_device, with_torch_sdp):
         jit_method=JITMethod.TorchScriptTrace,
         with_torch_sdp=with_torch_sdp,
     )
+
+
+@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
+@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
+def test_encoder_hf_serialization_roundtrip(torch_device):
+    assert_model_hf_serialization_roundtrip(
+        XLMREncoder, "explosion-testing/xlm-roberta-test", torch_device
+    )
diff --git a/curated_transformers/util/string.py b/curated_transformers/util/string.py
new file mode 100644
index 00000000..d2d5c183
--- /dev/null
+++ b/curated_transformers/util/string.py
@@ -0,0 +1,249 @@
+import re
+from abc import ABC, abstractmethod
+from typing import Optional, Tuple
+
+
+class StringTransform(ABC):
+    """
+    Base class for reversible string transformations.
+    """
+
+    def __init__(self, reversible: bool = True):
+        super().__init__()
+        self._reversible = reversible
+
+    @abstractmethod
+    def _apply(self, string: str) -> str:
+        raise NotImplementedError
+
+    @abstractmethod
+    def _revert(self, string: str) -> str:
+        raise NotImplementedError
+
+    def apply(self, string: str) -> str:
+        """
+        Applies the transformation to the given string.
+
+        :param string:
+            String to transform.
+        :returns:
+            Transformed string.
+        """
+        return self._apply(string)
+
+    def revert(self, string: str) -> str:
+        """
+        Reverts the previously applied transformation of the given string.
+
+        :param string:
+            Previously transformed string.
+        :returns:
+            Reverted string.
+        """
+        if self._reversible:
+            return self._revert(string)
+        else:
+            return string
+
+
+class StringTransformations:
+    """
+    Provides factory methods for different string transformations.
+    """
+
+    @staticmethod
+    def regex_sub(
+        forward: Tuple[str, str], backward: Optional[Tuple[str, str]]
+    ) -> StringTransform:
+        """
+        Factory method to construct a string substitution transform
+        using regular expressions.
+
+        :param forward:
+            Tuple where the first string is a RegEx pattern
+            and the second the replacement.
+
+            This operation is performed when the :meth:`.apply`
+            method is invoked.
+        :param backward:
+            Optional tuple where the first string is a RegEx pattern
+            and the second the replacement.
+
+            This operation is performed when the :meth:`.revert`
+            method is invoked. If ``None``, it is a no-op.
+ """ + return StringSubRegEx(forward, backward) + + @staticmethod + def sub( + substring: str, replacement: str, *, reversible: bool = True + ) -> StringTransform: + """ + Factory method to construct a string substitution transform. + + :param substring: + The substring to be replaced. + :param replacement: + The replacement string. + :param reversible: + If the reverse transformation is to + be performed. + """ + return StringSub(substring, replacement, reversible=reversible) + + @staticmethod + def replace( + replacee: str, replacement: str, *, reversible: bool = True + ) -> StringTransform: + """ + Factory method to construct a full string replacement transform. + + :param replacee: + The full string to be replaced. + :param replacement: + The replacement string. + :param reversible: + If the reverse transformation is to + be performed. + """ + return StringReplace(replacee, replacement, reversible=reversible) + + @staticmethod + def remove_prefix(prefix: str, *, reversible: bool = True) -> StringTransform: + """ + Factory method to construct a string prefix removal transform. + + :param prefix: + Prefix to be removed. + :param reversible: + If the reverse transformation is to + be performed. + """ + return StringRemovePrefix(prefix, reversible=reversible) + + +class StringSubRegEx(StringTransform): + """ + Substitute a substring with another string using + regular expressions. + """ + + def __init__(self, forward: Tuple[str, str], backward: Optional[Tuple[str, str]]): + """ + Construct a reversible substitution. + + :param forward: + Tuple where the first string is a RegEx pattern + and the second the replacement. + + This operation is performed when the :meth:`.apply` + method is invoked. + :param backward: + Optional tuple where the first string is a RegEx pattern + and the second the replacement. + + This operation is performed when the :meth:`.revert` + method is invoked. If ``None``, it is a no-op. + """ + super().__init__(backward is not None) + + self.forward = forward + self.backward = backward + + def _apply(self, string: str) -> str: + return re.sub(self.forward[0], self.forward[1], string) + + def _revert(self, string: str) -> str: + if self.backward is None: + raise ValueError("Attempting to revert an irreversible string transform") + return re.sub(self.backward[0], self.backward[1], string) + + +class StringSub(StringTransform): + """ + Substitute a substring with another string. + """ + + def __init__(self, substring: str, replacement: str, *, reversible: bool = True): + """ + Construct a reversible substitution. + + :param substring: + The substring to be replaced. + :param replacement: + The replacement string. + :param reversible: + If the reverse transformation is to + be performed. + """ + super().__init__(reversible) + self.substring = substring + self.replacement = replacement + + def _apply(self, string: str) -> str: + return string.replace(self.substring, self.replacement) + + def _revert(self, string: str) -> str: + return string.replace(self.replacement, self.substring) + + +class StringReplace(StringTransform): + """ + Replaces an entire string with another. + """ + + def __init__(self, replacee: str, replacement: str, *, reversible: bool = True): + """ + Construct a reversible replacement. + + :param replacee: + The full string to be replaced. + :param replacement: + The replacement string. + :param reversible: + If the reverse transformation is to + be performed. 
+ """ + super().__init__(reversible) + self.replacee = replacee + self.replacement = replacement + + def _apply(self, string: str) -> str: + if string == self.replacee: + return self.replacement + else: + return string + + def _revert(self, string: str) -> str: + if string == self.replacement: + return self.replacee + else: + return string + + +class StringRemovePrefix(StringTransform): + """ + Strips a prefix from a given string. + """ + + def __init__(self, prefix: str, *, reversible: bool = True): + """ + Construct a reversible left strip. + + :param prefix: + Prefix to be removed. + :param reversible: + If the reverse transformation is to + be performed. + """ + + super().__init__(reversible) + self.prefix = prefix + + def _apply(self, string: str) -> str: + # TODO: Should be replaced with `removeprefix` once + # Python 3.9 is the minimum requirement. + return re.sub(f"^{re.escape(self.prefix)}", "", string) + + def _revert(self, string: str) -> str: + return f"{self.prefix}{string}" diff --git a/docs/source/api-compat.rst b/docs/source/api-compat.rst index f1dee1fb..0b753ee8 100644 --- a/docs/source/api-compat.rst +++ b/docs/source/api-compat.rst @@ -110,3 +110,5 @@ Version 1 to 2 * The factory methods of :py:class:`~curated_transformers.layers.AttentionHeads` add a new ``qkv_split`` argument which is mandatory in future versions. * The ``FromHFHub`` mixins will be renamed to ``FromHF``. +* The ``convert_hf_state_dict`` method in ``FromHFHub`` will be removed + in favour of ``state_dict_from_hf``.
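
A usage sketch (not part of the patch itself): the snippet below shows how an order-dependent list of ``StringTransform`` objects maps a Hugging Face parameter name to a curated name and back again, which is what the new ``state_dict_from_hf``/``state_dict_to_hf`` helpers are expected to do for every key in a state dict. The transform list and the parameter key are illustrative stand-ins, only loosely modeled on the RoBERTa renames in this patch, and are not the actual ``HF_PARAM_KEY_TRANSFORMS`` from ``roberta/_hf.py``; the revert-in-reverse-order loop is likewise an assumption about how the helpers compose the transforms.

from curated_transformers.util.string import StringTransformations

# Illustrative transforms; the real per-model lists live in each model's _hf.py.
transforms = [
    # Irreversible: the stripped prefix is not reinstated when converting back.
    StringTransformations.remove_prefix("roberta.", reversible=False),
    StringTransformations.regex_sub(
        (r"^encoder\.layer\.", "layers."), (r"^layers\.", "encoder.layer.")
    ),
    StringTransformations.sub(".attention.output.dense", ".mha.output"),
]

hf_key = "roberta.encoder.layer.0.attention.output.dense.weight"

# Hugging Face -> curated: apply the transforms in order.
curated_key = hf_key
for transform in transforms:
    curated_key = transform.apply(curated_key)
print(curated_key)  # layers.0.mha.output.weight

# Curated -> Hugging Face: revert the transforms in reverse order. The
# "roberta." prefix stays stripped because its transform is not reversible.
hf_key_roundtrip = curated_key
for transform in reversed(transforms):
    hf_key_roundtrip = transform.revert(hf_key_roundtrip)
print(hf_key_roundtrip)  # encoder.layer.0.attention.output.dense.weight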