Add support for converting Curated Transformer state dicts to Hugging Face compatible state dicts (#332)

* Add support for converting Curated Transformer state dicts to Hugging Face compatible state dicts.

* Typo

* Expand string transform class names

* Raise error during an unexpected reverse substitution

* `StringReplace` without `re`

* Add note about `removeprefix`

* Concat removed prefix

* Add `StringTransformations` to expose factory methods for
`StringTransform` subclasses

* Simplify `StringSubInvertible` and rename to `StringSub`
shadeMe committed Sep 27, 2023
1 parent 1f023dc commit 93cf07f
Showing 36 changed files with 1,004 additions and 521 deletions.
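The commit message introduces `StringTransformations` as a collection of factory methods for `StringTransform` subclasses and makes the key substitutions reversible. The diffs below only show the call sites; the following is a minimal sketch of the interface those call sites imply, not the library's actual implementation (which lives in `curated_transformers/util/string.py` and is not part of this excerpt). Everything except the factory-method names and their arguments is an assumption.

```python
# Minimal sketch of the interface implied by the call sites in this commit.
import re
from typing import Optional, Tuple


class StringTransform:
    """A key rewrite with a forward (HF -> curated) direction and an
    optional backward (curated -> HF) direction."""

    def __init__(
        self,
        forward: Tuple[str, str],
        backward: Optional[Tuple[str, str]],
        *,
        regex: bool,
    ):
        self.forward = forward
        self.backward = backward
        self.regex = regex

    def apply(self, key: str, *, reverse: bool = False) -> str:
        # The commit message mentions raising an error on an unexpected
        # reverse substitution; that validation is left out of this sketch.
        if reverse and self.backward is None:
            # One-way transform: nothing to undo when converting back to HF.
            return key
        pattern, replacement = self.backward if reverse else self.forward
        if self.regex:
            return re.sub(pattern, replacement, key)
        return key.replace(pattern, replacement)


class StringTransformations:
    """Factory methods matching the call sites in the per-model `_hf.py` tables."""

    @staticmethod
    def regex_sub(forward, backward) -> StringTransform:
        return StringTransform(forward, backward, regex=True)

    @staticmethod
    def sub(substring: str, replacement: str) -> StringTransform:
        return StringTransform(
            (substring, replacement), (replacement, substring), regex=False
        )

    @staticmethod
    def replace(key: str, replacement: str) -> StringTransform:
        # Plain string replacement; per the commit message, no `re` involved.
        return StringTransform((key, replacement), (replacement, key), regex=False)

    @staticmethod
    def remove_prefix(prefix: str, *, reversible: bool) -> StringTransform:
        # The commit notes mention `str.removeprefix`; an anchored regex gives
        # the same forward behaviour in this sketch.
        backward = ("^", prefix) if reversible else None
        return StringTransform((rf"^{re.escape(prefix)}", ""), backward, regex=True)
```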
119 changes: 50 additions & 69 deletions curated_transformers/models/albert/_hf.py
@@ -1,25 +1,56 @@
import re
from types import MappingProxyType
from typing import Any, Callable, Dict, Mapping, Tuple, Union

from torch import Tensor
from typing import Any, Callable, Dict, List, Tuple, Union

from ...layers.activations import Activation
from ..hf_hub import _process_hf_keys
from ...util.string import StringTransform, StringTransformations
from ..hf_hub.conversion import process_hf_keys
from .config import ALBERTConfig

HF_KEY_TO_CURATED_KEY = MappingProxyType(
{
"embeddings.word_embeddings.weight": "embeddings.piece_embeddings.weight",
"embeddings.token_type_embeddings.weight": "embeddings.type_embeddings.weight",
"embeddings.position_embeddings.weight": "embeddings.position_embeddings.weight",
"embeddings.LayerNorm.weight": "embeddings.embed_output_layer_norm.weight",
"embeddings.LayerNorm.bias": "embeddings.embed_output_layer_norm.bias",
# Embedding projection
"encoder.embedding_hidden_mapping_in.weight": "embeddings.projection.weight",
"encoder.embedding_hidden_mapping_in.bias": "embeddings.projection.bias",
}
)
# Order-dependent.
HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [
# Prefixes.
StringTransformations.remove_prefix("albert.", reversible=False),
StringTransformations.regex_sub(
(r"^encoder\.(embedding_|albert_layer)", "\\1"),
(r"^(embedding_|albert_layer)", "encoder.\\1"),
),
# Layer groups
StringTransformations.regex_sub(
(r"^albert_layer_groups\.", "groups."), (r"^groups\.", "albert_layer_groups.")
),
# Inner layers.
StringTransformations.sub(".albert_layers.", ".group_layers."),
# Attention blocks.
StringTransformations.sub(".attention.", ".mha."),
StringTransformations.sub(".mha.LayerNorm", ".attn_residual_layer_norm"),
StringTransformations.sub(".mha.dense", ".mha.output"),
# Pointwise feed-forward layers.
StringTransformations.sub(".ffn.", ".ffn.intermediate."),
StringTransformations.sub(".ffn_output.", ".ffn.output."),
StringTransformations.sub(".full_layer_layer_norm.", ".ffn_residual_layer_norm."),
# Embeddings.
StringTransformations.replace(
"embeddings.word_embeddings.weight", "embeddings.piece_embeddings.weight"
),
StringTransformations.replace(
"embeddings.token_type_embeddings.weight", "embeddings.type_embeddings.weight"
),
StringTransformations.replace(
"embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"
),
StringTransformations.replace(
"embeddings.LayerNorm.weight", "embeddings.embed_output_layer_norm.weight"
),
StringTransformations.replace(
"embeddings.LayerNorm.bias", "embeddings.embed_output_layer_norm.bias"
),
# Embedding projection.
StringTransformations.replace(
"embedding_hidden_mapping_in.weight", "embeddings.projection.weight"
),
StringTransformations.replace(
"embedding_hidden_mapping_in.bias", "embeddings.projection.bias"
),
]

HF_CONFIG_KEY_MAPPING: Dict[str, Union[str, Tuple[str, Callable]]] = {
"attention_probs_dropout_prob": "attention_probs_dropout_prob",
@@ -40,55 +71,5 @@


def convert_hf_config(hf_config: Any) -> ALBERTConfig:
kwargs = _process_hf_keys("ALBERT", hf_config, HF_CONFIG_KEY_MAPPING)
kwargs = process_hf_keys("ALBERT", hf_config, HF_CONFIG_KEY_MAPPING)
return ALBERTConfig(model_max_length=hf_config["max_position_embeddings"], **kwargs)


def convert_hf_state_dict(params: Mapping[str, Tensor]) -> Mapping[str, Tensor]:
# Strip the `albert` prefix from ALBERT model parameters.
stripped_params = {re.sub(r"^albert\.", "", k): v for k, v in params.items()}

# The ALBERT encoder parameters have the following form:
#
# encoder.albert_layer_groups.{hidden_group}.albert_layers.{inner_layer}.{param_name}
#
# hidden_group is in [0, n_hidden_group)
# inner_layer is in [0, n_layers_per_group)

out = {}
for name, parameter in stripped_params.items():
if "encoder.albert_layer" not in name:
continue

# TODO: Make these substitutions less ugly.

# Remove the prefix and rename.
name = re.sub(r"^encoder\.", "", name)

# Layer groups
name = re.sub(r"^albert_layer_groups\.", "groups.", name)

# Inner layers.
name = re.sub(r"\.albert_layers\.", ".group_layers.", name)

# Attention blocks.
name = re.sub(r"\.attention\.", ".mha.", name)
name = re.sub(r"\.mha\.LayerNorm", r".attn_residual_layer_norm", name)
name = re.sub(r"\.mha\.dense\.", r".mha.output.", name)

# Pointwise feed-forward layers.
name = re.sub(r"\.ffn\.", r".ffn.intermediate.", name)
name = re.sub(r"\.ffn_output\.", r".ffn.output.", name)
name = re.sub(
r"\.full_layer_layer_norm\.",
r".ffn_residual_layer_norm.",
name,
)

out[name] = parameter

for hf_name, curated_name in HF_KEY_TO_CURATED_KEY.items():
if hf_name in stripped_params:
out[curated_name] = stripped_params[hf_name]

return out
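
The per-model `convert_hf_state_dict` function removed above is replaced by shared `state_dict_from_hf` and `state_dict_to_hf` helpers from `..hf_hub.conversion`, driven by the ordered transform table. Their implementation is not part of this excerpt; the sketch below, which builds on the `StringTransform` interface sketched earlier, shows one way such helpers could apply the table in both directions.

```python
# Sketch only: the real helpers live in curated_transformers/models/hf_hub/
# conversion.py, which is not shown in this excerpt. `StringTransform` refers
# to the sketched interface above.
from typing import List, Mapping

from torch import Tensor


def state_dict_from_hf(
    params: Mapping[str, Tensor], transforms: List["StringTransform"]
) -> Mapping[str, Tensor]:
    # HF -> curated: apply the transforms in table order (the table is
    # order-dependent, as its comment notes).
    out = {}
    for name, parameter in params.items():
        for transform in transforms:
            name = transform.apply(name)
        out[name] = parameter
    return out


def state_dict_to_hf(
    params: Mapping[str, Tensor], transforms: List["StringTransform"]
) -> Mapping[str, Tensor]:
    # Curated -> HF: undo the transforms in reverse order; one-way transforms
    # (`reversible=False` / `backward=None`) leave keys untouched.
    out = {}
    for name, parameter in params.items():
        for transform in reversed(transforms):
            name = transform.apply(name, reverse=True)
        out[name] = parameter
    return out
```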
15 changes: 12 additions & 3 deletions curated_transformers/models/albert/encoder.py
@@ -11,9 +11,10 @@
TransformerEmbeddings,
)
from ..hf_hub import FromHFHub
from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf
from ..module import EncoderModule
from ..output import ModelOutput
from ._hf import convert_hf_config, convert_hf_state_dict
from ._hf import HF_PARAM_KEY_TRANSFORMS, convert_hf_config
from .config import ALBERTConfig
from .layer_group import ALBERTLayerGroup

@@ -99,8 +100,16 @@ def forward(
return ModelOutput(all_outputs=[embeddings, *layer_outputs])

@classmethod
def convert_hf_state_dict(cls, params: Mapping[str, Tensor]):
return convert_hf_state_dict(params)
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
) -> Mapping[str, Tensor]:
return state_dict_from_hf(params, HF_PARAM_KEY_TRANSFORMS)

@classmethod
def state_dict_to_hf(
cls: Type[Self], params: Mapping[str, Tensor]
) -> Mapping[str, Tensor]:
return state_dict_to_hf(params, HF_PARAM_KEY_TRANSFORMS)

@classmethod
def from_hf_config(
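With both classmethods in place, an HF checkpoint's keys can in principle be round-tripped through the curated naming scheme. A toy sketch follows, assuming the encoder class defined in this file is named `ALBERTEncoder` and using a placeholder tensor rather than a real checkpoint.

```python
# Toy round trip over a single placeholder parameter; real state dicts would
# come from e.g. transformers' AutoModel.from_pretrained(...).state_dict().
import torch

from curated_transformers.models.albert.encoder import ALBERTEncoder  # class name assumed

hf_params = {"albert.embeddings.word_embeddings.weight": torch.zeros(30000, 128)}

curated = ALBERTEncoder.state_dict_from_hf(hf_params)
# -> key becomes "embeddings.piece_embeddings.weight"

hf_again = ALBERTEncoder.state_dict_to_hf(curated)
# -> key becomes "embeddings.word_embeddings.weight"; under the table above,
#    remove_prefix("albert.", reversible=False) is one-way, so the "albert."
#    prefix is presumably not restored.
```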
121 changes: 53 additions & 68 deletions curated_transformers/models/bert/_hf.py
@@ -1,23 +1,59 @@
import re
from types import MappingProxyType
from typing import Any, Callable, Dict, Mapping, Tuple, Union

from torch import Tensor
from typing import Any, Callable, Dict, List, Tuple, Union

from ...layers.activations import Activation
from ..hf_hub import _process_hf_keys
from ...util.string import StringTransform, StringTransformations
from ..hf_hub.conversion import process_hf_keys
from .config import BERTConfig

HF_KEY_TO_CURATED_KEY = MappingProxyType(
{
"embeddings.word_embeddings.weight": "embeddings.piece_embeddings.weight",
"embeddings.token_type_embeddings.weight": "embeddings.type_embeddings.weight",
"embeddings.position_embeddings.weight": "embeddings.position_embeddings.weight",
"embeddings.LayerNorm.weight": "embeddings.embed_output_layer_norm.weight",
"embeddings.LayerNorm.bias": "embeddings.embed_output_layer_norm.bias",
}
)

# Order-dependent.
HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [
# Old HF parameter names (one-way transforms).
StringTransformations.regex_sub((r"\.gamma$", ".weight"), backward=None),
StringTransformations.regex_sub((r"\.beta$", ".bias"), backward=None),
# Prefixes.
StringTransformations.remove_prefix("bert.", reversible=False),
StringTransformations.regex_sub(
(r"^encoder\.(layer\.)", "\\1"),
(r"^(layer\.)", "encoder.\\1"),
),
# Layers.
StringTransformations.regex_sub((r"^layer", "layers"), (r"^layers", "layer")),
# Attention blocks.
StringTransformations.regex_sub(
(r"\.attention\.self\.(query|key|value)", ".mha.\\1"),
(r"\.mha\.(query|key|value)", ".attention.self.\\1"),
),
StringTransformations.sub(".attention.output.dense", ".mha.output"),
StringTransformations.sub(
r".attention.output.LayerNorm", ".attn_residual_layer_norm"
),
# Pointwise feed-forward layers.
StringTransformations.sub(".intermediate.dense", ".ffn.intermediate"),
StringTransformations.regex_sub(
(r"(\.\d+)\.output\.LayerNorm", "\\1.ffn_residual_layer_norm"),
(r"(\.\d+)\.ffn_residual_layer_norm", "\\1.output.LayerNorm"),
),
StringTransformations.regex_sub(
(r"(\.\d+)\.output\.dense", "\\1.ffn.output"),
(r"(\.\d+)\.ffn\.output", "\\1.output.dense"),
),
# Embeddings.
StringTransformations.replace(
"embeddings.word_embeddings.weight", "embeddings.piece_embeddings.weight"
),
StringTransformations.replace(
"embeddings.token_type_embeddings.weight", "embeddings.type_embeddings.weight"
),
StringTransformations.replace(
"embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"
),
StringTransformations.replace(
"embeddings.LayerNorm.weight", "embeddings.embed_output_layer_norm.weight"
),
StringTransformations.replace(
"embeddings.LayerNorm.bias", "embeddings.embed_output_layer_norm.bias"
),
]

HF_CONFIG_KEY_MAPPING: Dict[str, Union[str, Tuple[str, Callable]]] = {
"attention_probs_dropout_prob": "attention_probs_dropout_prob",
@@ -35,61 +71,10 @@


def convert_hf_config(hf_config: Any) -> BERTConfig:
kwargs = _process_hf_keys("BERT", hf_config, HF_CONFIG_KEY_MAPPING)
kwargs = process_hf_keys("BERT", hf_config, HF_CONFIG_KEY_MAPPING)

return BERTConfig(
embedding_width=hf_config["hidden_size"],
model_max_length=hf_config["max_position_embeddings"],
**kwargs,
)


def convert_hf_state_dict(params: Mapping[str, Tensor]) -> Mapping[str, Tensor]:
out = {}

renamed_params = _rename_old_hf_names(params)

# Strip the `bert` prefix from BERT model parameters.
stripped_params = {re.sub(r"^bert\.", "", k): v for k, v in renamed_params.items()}

for name, parameter in stripped_params.items():
if "encoder.layer." not in name:
continue

# TODO: Make these substitutions less ugly.

# Remove the prefix and rename the internal 'layers' variable.
name = re.sub(r"^encoder\.", "", name)
name = re.sub(r"^layer", "layers", name)

# The HF model has one more level of indirection for the output layers in their
# attention heads and the feed-forward network layers.
name = re.sub(r"\.attention\.self\.(query|key|value)", r".mha.\1", name)
name = re.sub(r"\.attention\.(output)\.dense", r".mha.\1", name)
name = re.sub(
r"\.attention\.output\.LayerNorm", r".attn_residual_layer_norm", name
)
name = re.sub(r"\.(intermediate)\.dense", r".ffn.\1", name)
name = re.sub(
r"(\.\d+)\.output\.LayerNorm", r"\1.ffn_residual_layer_norm", name
)
name = re.sub(r"(\.\d+)\.(output)\.dense", r"\1.ffn.\2", name)

out[name] = parameter

for hf_name, curated_name in HF_KEY_TO_CURATED_KEY.items():
if hf_name in stripped_params:
out[curated_name] = stripped_params[hf_name]

return out


def _rename_old_hf_names(
params: Mapping[str, Tensor],
) -> Mapping[str, Tensor]:
out = {}
for name, parameter in params.items():
name = re.sub(r"\.gamma$", ".weight", name)
name = re.sub(r"\.beta$", ".bias", name)
out[name] = parameter
return out
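
The first two BERT transforms normalize the legacy `.gamma`/`.beta` parameter names still found in some older HF checkpoints (previously handled by `_rename_old_hf_names`, removed above). Passing `backward=None` makes them one-way, so the legacy names are never reintroduced when converting back. A toy illustration, reusing the `StringTransform` sketch from earlier:

```python
# Illustration only; relies on the sketched interface above, not the real classes.
legacy_gamma = StringTransformations.regex_sub((r"\.gamma$", ".weight"), backward=None)

print(legacy_gamma.apply("encoder.layer.0.attention.output.LayerNorm.gamma"))
# encoder.layer.0.attention.output.LayerNorm.weight

print(legacy_gamma.apply("encoder.layer.0.attention.output.LayerNorm.weight", reverse=True))
# unchanged: with backward=None there is nothing to undo
```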
15 changes: 12 additions & 3 deletions curated_transformers/models/bert/encoder.py
@@ -16,8 +16,9 @@
TransformerLayerNorms,
)
from ..hf_hub import FromHFHub
from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf
from ..transformer import TransformerEncoder
from ._hf import convert_hf_config, convert_hf_state_dict
from ._hf import HF_PARAM_KEY_TRANSFORMS, convert_hf_config
from .config import BERTConfig

# Only provided as typing.Self in Python 3.11+.
@@ -105,8 +106,16 @@ def __init__(self, config: BERTConfig, *, device: Optional[torch.device] = None)
)

@classmethod
def convert_hf_state_dict(cls, params: Mapping[str, Tensor]):
return convert_hf_state_dict(params)
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
) -> Mapping[str, Tensor]:
return state_dict_from_hf(params, HF_PARAM_KEY_TRANSFORMS)

@classmethod
def state_dict_to_hf(
cls: Type[Self], params: Mapping[str, Tensor]
) -> Mapping[str, Tensor]:
return state_dict_to_hf(params, HF_PARAM_KEY_TRANSFORMS)

@classmethod
def from_hf_config(