diff --git a/curated_transformers/models/bert/encoder.py b/curated_transformers/models/bert/encoder.py
index 5d924304..c32034bb 100644
--- a/curated_transformers/models/bert/encoder.py
+++ b/curated_transformers/models/bert/encoder.py
@@ -12,13 +12,14 @@ class BertEncoder(Module):
-    def __init__(
-        self,
-        config: BertConfig,
-    ):
+    def __init__(self, config: BertConfig, embeddings: Optional[Module] = None):
         super().__init__()

-        self.embeddings = BertEmbeddings(config.embedding, config.layer)
+        if embeddings is None:
+            self.embeddings = BertEmbeddings(config.embedding, config.layer)
+        else:
+            self.embeddings = embeddings
+
         self.padding_idx = config.padding_idx
         self.max_seq_len = config.model_max_length
         self.layers = torch.nn.ModuleList(
diff --git a/curated_transformers/models/embeddings.py b/curated_transformers/models/embeddings.py
index 94d3155f..c381e89f 100644
--- a/curated_transformers/models/embeddings.py
+++ b/curated_transformers/models/embeddings.py
@@ -6,7 +6,7 @@
 # https://pytorch.org/tutorials/beginner/transformer_tutorial.html
 class SinusoidalPositionalEmbedding(Module):
-    def __init__(self, dim: int, max_len: int, *, normalize=True):
+    def __init__(self, dim: int, max_len: int, *, normalize: bool = True):
         super().__init__()

         position = torch.arange(max_len).unsqueeze(1)
@@ -16,16 +16,16 @@ def __init__(self, dim: int, max_len: int, *, normalize=True):
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)

-        if normalize == True:
-            l2 = torch.norm(pe, dim=-1)
+        if normalize:
+            l2 = torch.linalg.vector_norm(pe, dim=-1)
             pe /= l2.unsqueeze(-1)

-        self.pe = pe
-        self.pe.requires_grad = False
+        pe.requires_grad = False
+        self.register_buffer("pe", pe, persistent=False)

-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, input: Tensor) -> Tensor:
         """
         Shapes:
             x - (batch, seq_len)
         """
-        return self.pe[x.size(1), :]
+        return self.pe[: input.size(1), :]
diff --git a/curated_transformers/models/hf_util.py b/curated_transformers/models/hf_util.py
index 19be7ebe..94ab2262 100644
--- a/curated_transformers/models/hf_util.py
+++ b/curated_transformers/models/hf_util.py
@@ -143,12 +143,14 @@ def _convert_albert_base_state(
     out["embeddings.word_embeddings.weight"] = stripped_params[
         "embeddings.word_embeddings.weight"
     ]
-    out["embeddings.token_type_embeddings.weight"] = stripped_params[
-        "embeddings.token_type_embeddings.weight"
-    ]
-    out["embeddings.position_embeddings.weight"] = stripped_params[
-        "embeddings.position_embeddings.weight"
-    ]
+    if "embeddings.token_type_embeddings.weight" in stripped_params:
+        out["embeddings.token_type_embeddings.weight"] = stripped_params[
+            "embeddings.token_type_embeddings.weight"
+        ]
+    if "embeddings.position_embeddings.weight" in stripped_params:
+        out["embeddings.position_embeddings.weight"] = stripped_params[
+            "embeddings.position_embeddings.weight"
+        ]
     out["embeddings.layer_norm.weight"] = stripped_params["embeddings.LayerNorm.weight"]
     out["embeddings.layer_norm.bias"] = stripped_params["embeddings.LayerNorm.bias"]
@@ -198,12 +200,14 @@ def _convert_bert_base_state(
     out["embeddings.word_embeddings.weight"] = stripped_params[
         "embeddings.word_embeddings.weight"
     ]
-    out["embeddings.token_type_embeddings.weight"] = stripped_params[
-        "embeddings.token_type_embeddings.weight"
-    ]
-    out["embeddings.position_embeddings.weight"] = stripped_params[
-        "embeddings.position_embeddings.weight"
-    ]
+    if "embeddings.token_type_embeddings.weight" in stripped_params:
+        out["embeddings.token_type_embeddings.weight"] = stripped_params[
+            "embeddings.token_type_embeddings.weight"
+        ]
+    if "embeddings.position_embeddings.weight" in stripped_params:
+        out["embeddings.position_embeddings.weight"] = stripped_params[
+            "embeddings.position_embeddings.weight"
+        ]
     out["embeddings.layer_norm.weight"] = stripped_params["embeddings.LayerNorm.weight"]
     out["embeddings.layer_norm.bias"] = stripped_params["embeddings.LayerNorm.bias"]
@@ -245,12 +249,14 @@ def _convert_roberta_base_state(
     out["embeddings.inner.word_embeddings.weight"] = stripped_params[
         "embeddings.word_embeddings.weight"
     ]
-    out["embeddings.inner.token_type_embeddings.weight"] = stripped_params[
-        "embeddings.token_type_embeddings.weight"
-    ]
-    out["embeddings.inner.position_embeddings.weight"] = stripped_params[
-        "embeddings.position_embeddings.weight"
-    ]
+    if "embeddings.token_type_embeddings.weight" in stripped_params:
+        out["embeddings.inner.token_type_embeddings.weight"] = stripped_params[
+            "embeddings.token_type_embeddings.weight"
+        ]
+    if "embeddings.position_embeddings.weight" in stripped_params:
+        out["embeddings.inner.position_embeddings.weight"] = stripped_params[
+            "embeddings.position_embeddings.weight"
+        ]
     out["embeddings.inner.layer_norm.weight"] = stripped_params[
         "embeddings.LayerNorm.weight"
     ]
diff --git a/curated_transformers/models/roberta/encoder.py b/curated_transformers/models/roberta/encoder.py
index ca829854..143db994 100644
--- a/curated_transformers/models/roberta/encoder.py
+++ b/curated_transformers/models/roberta/encoder.py
@@ -12,12 +12,16 @@ class RobertaEncoder(Module):
-    def __init__(self, config: RobertaConfig):
+    def __init__(self, config: RobertaConfig, embeddings: Optional[Module] = None):
         super().__init__()

-        self.embeddings = RobertaEmbeddings(
-            config.embedding, config.layer, padding_idx=config.padding_idx
-        )
+        if embeddings is None:
+            self.embeddings = RobertaEmbeddings(
+                config.embedding, config.layer, padding_idx=config.padding_idx
+            )
+        else:
+            self.embeddings = embeddings
+
         self.padding_idx = config.padding_idx
         self.max_seq_len = config.model_max_length
         self.layers = torch.nn.ModuleList(
diff --git a/curated_transformers/models/sinusoidal/__init__.py b/curated_transformers/models/sinusoidal/__init__.py
new file mode 100644
index 00000000..4621ce1e
--- /dev/null
+++ b/curated_transformers/models/sinusoidal/__init__.py
@@ -0,0 +1 @@
+from .embeddings import SinusoidalEmbeddings
diff --git a/curated_transformers/models/sinusoidal/embeddings.py b/curated_transformers/models/sinusoidal/embeddings.py
new file mode 100644
index 00000000..de1d9885
--- /dev/null
+++ b/curated_transformers/models/sinusoidal/embeddings.py
@@ -0,0 +1,45 @@
+import torch
+from typing import Optional
+from ..bert.config import BertEmbeddingConfig
+from ..embeddings import SinusoidalPositionalEmbedding
+
+from torch import Tensor
+from torch.nn import Module, Embedding, Dropout, LayerNorm
+
+
+class SinusoidalEmbeddings(Module):
+    def __init__(self, embedding_config: BertEmbeddingConfig):
+        super().__init__()
+
+        self.word_embeddings = Embedding(
+            num_embeddings=embedding_config.vocab_size,
+            embedding_dim=embedding_config.embedding_width,
+        )
+        self.layer_norm = LayerNorm(
+            embedding_config.embedding_width, eps=embedding_config.layer_norm_eps
+        )
+        self.dropout = Dropout(p=embedding_config.dropout_prob)
+
+        self.sinusoidal = SinusoidalPositionalEmbedding(
+            dim=embedding_config.embedding_width, max_len=10000
+        )
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        token_type_ids: Optional[Tensor] = None,
+        position_ids: Optional[Tensor] = None,
+    ) -> Tensor:
+        """
+        Shapes:
+            input_ids, token_type_ids, position_ids - (batch, seq_len)
+        """
+        embeddings = self.word_embeddings(input_ids.long())
+        _, seq_len, dim = embeddings.shape
+        position_embeddings = self.sinusoidal(embeddings)
+        with torch.no_grad():
+            embeddings += position_embeddings
+
+        embeddings = self.layer_norm(embeddings)
+
+        return self.dropout(embeddings)
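
Usage sketch (illustrative, not part of the patch): the optional `embeddings` constructor argument added above lets callers inject the new `SinusoidalEmbeddings` module in place of the default learned embeddings. Only the `BertEncoder(config, embeddings=...)` and `SinusoidalEmbeddings(embedding_config)` signatures and the `curated_transformers.models.sinusoidal` export come from this diff; the other import paths and the assumption that `BertConfig` can be constructed with default values are hypothetical.

    # Hypothetical wiring example; see the hedging note above.
    from curated_transformers.models.bert.config import BertConfig  # assumed location of BertConfig
    from curated_transformers.models.bert.encoder import BertEncoder
    from curated_transformers.models.sinusoidal import SinusoidalEmbeddings

    config = BertConfig()  # assumed to provide usable defaults

    # Fixed sinusoidal position embeddings are injected instead of the learned
    # BertEmbeddings that the encoder would otherwise construct itself.
    encoder = BertEncoder(config, embeddings=SinusoidalEmbeddings(config.embedding))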