Add basic support for BERT/RoBERTa with sinusoidal embeddings
danieldk committed Oct 13, 2023
1 parent f2bca8f commit 9ec3371
Showing 6 changed files with 91 additions and 34 deletions.
11 changes: 6 additions & 5 deletions curated_transformers/models/bert/encoder.py
@@ -12,13 +12,14 @@

 class BertEncoder(Module):
-    def __init__(
-        self,
-        config: BertConfig,
-    ):
+    def __init__(self, config: BertConfig, embeddings: Optional[Module] = None):
         super().__init__()

-        self.embeddings = BertEmbeddings(config.embedding, config.layer)
+        if embeddings is None:
+            self.embeddings = BertEmbeddings(config.embedding, config.layer)
+        else:
+            self.embeddings = embeddings

         self.padding_idx = config.padding_idx
         self.max_seq_len = config.model_max_length
         self.layers = torch.nn.ModuleList(
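
Note: the new optional embeddings argument lets callers replace the default learned BertEmbeddings with any compatible Module. A minimal usage sketch, assuming an already-constructed BertConfig instance named config and importing by the module paths shown in this commit (package-level re-exports are not visible here):

    from curated_transformers.models.bert.encoder import BertEncoder
    from curated_transformers.models.sinusoidal import SinusoidalEmbeddings

    # Default behaviour: the encoder builds its own learned BertEmbeddings.
    encoder = BertEncoder(config)

    # Override: inject the sinusoidal embeddings added in this commit.
    # config.embedding is the BertEmbeddingConfig both embedding classes consume.
    sinusoidal_encoder = BertEncoder(
        config, embeddings=SinusoidalEmbeddings(config.embedding)
    )
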
14 changes: 7 additions & 7 deletions curated_transformers/models/embeddings.py
@@ -6,7 +6,7 @@

 # https://pytorch.org/tutorials/beginner/transformer_tutorial.html
 class SinusoidalPositionalEmbedding(Module):
-    def __init__(self, dim: int, max_len: int, *, normalize=True):
+    def __init__(self, dim: int, max_len: int, *, normalize: bool = True):
         super().__init__()

         position = torch.arange(max_len).unsqueeze(1)
@@ -16,16 +16,16 @@ def __init__(self, dim: int, max_len: int, *, normalize=True):
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)

-        if normalize == True:
-            l2 = torch.norm(pe, dim=-1)
+        if normalize:
+            l2 = torch.linalg.vector_norm(pe, dim=-1)
             pe /= l2.unsqueeze(-1)

-        self.pe = pe
-        self.pe.requires_grad = False
+        pe.requires_grad = False
+        self.register_buffer("pe", pe, persistent=False)

-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, input: Tensor) -> Tensor:
         """
         Shapes:
             x - (batch, seq_len)
         """
-        return self.pe[x.size(1), :]
+        return self.pe[: input.size(1), :]
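
Note: the positional table is now registered as a non-persistent buffer, and forward returns the first seq_len rows rather than a single row. A short sketch of the resulting behaviour (sizes are illustrative):

    import torch

    from curated_transformers.models.embeddings import SinusoidalPositionalEmbedding

    pos = SinusoidalPositionalEmbedding(dim=768, max_len=512)

    input_ids = torch.zeros((2, 10), dtype=torch.long)  # (batch, seq_len)
    pe = pos(input_ids)
    print(pe.shape)  # torch.Size([10, 768]), one row per position up to seq_len

    # With register_buffer(..., persistent=False), the table moves with the
    # module across devices and dtypes but is not written to the state dict.
    assert "pe" not in pos.state_dict()
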
42 changes: 24 additions & 18 deletions curated_transformers/models/hf_util.py
@@ -143,12 +143,14 @@ def _convert_albert_base_state(
     out["embeddings.word_embeddings.weight"] = stripped_params[
         "embeddings.word_embeddings.weight"
     ]
-    out["embeddings.token_type_embeddings.weight"] = stripped_params[
-        "embeddings.token_type_embeddings.weight"
-    ]
-    out["embeddings.position_embeddings.weight"] = stripped_params[
-        "embeddings.position_embeddings.weight"
-    ]
+    if "embeddings.token_type_embeddings.weight" in stripped_params:
+        out["embeddings.token_type_embeddings.weight"] = stripped_params[
+            "embeddings.token_type_embeddings.weight"
+        ]
+    if "embeddings.position_embeddings.weight" in stripped_params:
+        out["embeddings.position_embeddings.weight"] = stripped_params[
+            "embeddings.position_embeddings.weight"
+        ]
     out["embeddings.layer_norm.weight"] = stripped_params["embeddings.LayerNorm.weight"]
     out["embeddings.layer_norm.bias"] = stripped_params["embeddings.LayerNorm.bias"]

@@ -198,12 +200,14 @@ def _convert_bert_base_state(
     out["embeddings.word_embeddings.weight"] = stripped_params[
         "embeddings.word_embeddings.weight"
     ]
-    out["embeddings.token_type_embeddings.weight"] = stripped_params[
-        "embeddings.token_type_embeddings.weight"
-    ]
-    out["embeddings.position_embeddings.weight"] = stripped_params[
-        "embeddings.position_embeddings.weight"
-    ]
+    if "embeddings.token_type_embeddings.weight" in stripped_params:
+        out["embeddings.token_type_embeddings.weight"] = stripped_params[
+            "embeddings.token_type_embeddings.weight"
+        ]
+    if "embeddings.position_embeddings.weight" in stripped_params:
+        out["embeddings.position_embeddings.weight"] = stripped_params[
+            "embeddings.position_embeddings.weight"
+        ]
     out["embeddings.layer_norm.weight"] = stripped_params["embeddings.LayerNorm.weight"]
     out["embeddings.layer_norm.bias"] = stripped_params["embeddings.LayerNorm.bias"]

@@ -245,12 +249,14 @@ def _convert_roberta_base_state(
     out["embeddings.inner.word_embeddings.weight"] = stripped_params[
         "embeddings.word_embeddings.weight"
     ]
-    out["embeddings.inner.token_type_embeddings.weight"] = stripped_params[
-        "embeddings.token_type_embeddings.weight"
-    ]
-    out["embeddings.inner.position_embeddings.weight"] = stripped_params[
-        "embeddings.position_embeddings.weight"
-    ]
+    if "embeddings.token_type_embeddings.weight" in stripped_params:
+        out["embeddings.inner.token_type_embeddings.weight"] = stripped_params[
+            "embeddings.token_type_embeddings.weight"
+        ]
+    if "embeddings.position_embeddings.weight" in stripped_params:
+        out["embeddings.inner.position_embeddings.weight"] = stripped_params[
+            "embeddings.position_embeddings.weight"
+        ]
     out["embeddings.inner.layer_norm.weight"] = stripped_params[
         "embeddings.LayerNorm.weight"
     ]
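
Note: all three converters now copy the token-type and position embedding weights only when the Hugging Face checkpoint actually contains them, so checkpoints without learned positional parameters (for example, models using sinusoidal embeddings) convert without a KeyError. A standalone sketch of the guarded-copy pattern, not the library's converter itself (names are illustrative):

    from typing import Dict

    import torch
    from torch import Tensor


    def copy_if_present(
        src: Dict[str, Tensor], dst: Dict[str, Tensor], src_key: str, dst_key: str
    ) -> None:
        # Copy a parameter only when the source checkpoint provides it.
        if src_key in src:
            dst[dst_key] = src[src_key]


    # A checkpoint that relies on sinusoidal positions carries no learned
    # position embeddings, so the key is skipped instead of raising.
    stripped_params = {"embeddings.word_embeddings.weight": torch.zeros(10, 4)}
    out: Dict[str, Tensor] = {}
    copy_if_present(
        stripped_params,
        out,
        "embeddings.position_embeddings.weight",
        "embeddings.position_embeddings.weight",
    )
    assert "embeddings.position_embeddings.weight" not in out
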
12 changes: 8 additions & 4 deletions curated_transformers/models/roberta/encoder.py
@@ -12,12 +12,16 @@

 class RobertaEncoder(Module):
-    def __init__(self, config: RobertaConfig):
+    def __init__(self, config: RobertaConfig, embeddings: Optional[Module] = None):
         super().__init__()

-        self.embeddings = RobertaEmbeddings(
-            config.embedding, config.layer, padding_idx=config.padding_idx
-        )
+        if embeddings is None:
+            self.embeddings = RobertaEmbeddings(
+                config.embedding, config.layer, padding_idx=config.padding_idx
+            )
+        else:
+            self.embeddings = embeddings

         self.padding_idx = config.padding_idx
         self.max_seq_len = config.model_max_length
         self.layers = torch.nn.ModuleList(
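
Note: RobertaEncoder gains the same override as BertEncoder above. A minimal sketch, assuming an existing RobertaConfig instance named config and that config.embedding is accepted by SinusoidalEmbeddings (not shown in this diff):

    from curated_transformers.models.roberta.encoder import RobertaEncoder
    from curated_transformers.models.sinusoidal import SinusoidalEmbeddings

    encoder = RobertaEncoder(config, embeddings=SinusoidalEmbeddings(config.embedding))
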
1 change: 1 addition & 0 deletions curated_transformers/models/sinusoidal/__init__.py
@@ -0,0 +1 @@
+from .embeddings import SinusoidalEmbeddings
45 changes: 45 additions & 0 deletions curated_transformers/models/sinusoidal/embeddings.py
@@ -0,0 +1,45 @@
+import torch
+from typing import Optional
+from ..bert.config import BertEmbeddingConfig
+from ..embeddings import SinusoidalPositionalEmbedding
+
+from torch import Tensor
+from torch.nn import Module, Embedding, Dropout, LayerNorm
+
+
+class SinusoidalEmbeddings(Module):
+    def __init__(self, embedding_config: BertEmbeddingConfig):
+        super().__init__()
+
+        self.word_embeddings = Embedding(
+            num_embeddings=embedding_config.vocab_size,
+            embedding_dim=embedding_config.embedding_width,
+        )
+        self.layer_norm = LayerNorm(
+            embedding_config.embedding_width, eps=embedding_config.layer_norm_eps
+        )
+        self.dropout = Dropout(p=embedding_config.dropout_prob)
+
+        self.sinusoidal = SinusoidalPositionalEmbedding(
+            dim=embedding_config.embedding_width, max_len=10000
+        )
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        token_type_ids: Optional[Tensor] = None,
+        position_ids: Optional[Tensor] = None,
+    ) -> Tensor:
+        """
+        Shapes:
+            input_ids, token_type_ids, position_ids - (batch, seq_len)
+        """
+        embeddings = self.word_embeddings(input_ids.long())
+        _, seq_len, dim = embeddings.shape
+        position_embeddings = self.sinusoidal(embeddings)
+        with torch.no_grad():
+            embeddings += position_embeddings
+
+        embeddings = self.layer_norm(embeddings)
+
+        return self.dropout(embeddings)
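
Note: the new module adds fixed sinusoidal position vectors to the word embeddings under no_grad, then applies layer normalization and dropout. A small sketch with a stand-in config object; only the attributes read by the constructor are assumed, and SimpleNamespace is used here instead of the real BertEmbeddingConfig:

    import torch
    from types import SimpleNamespace

    from curated_transformers.models.sinusoidal import SinusoidalEmbeddings

    # Stand-in for BertEmbeddingConfig, carrying only the attributes read above.
    embedding_config = SimpleNamespace(
        vocab_size=30000, embedding_width=768, layer_norm_eps=1e-12, dropout_prob=0.1
    )

    embeddings = SinusoidalEmbeddings(embedding_config)
    input_ids = torch.randint(0, 30000, (2, 16))  # (batch, seq_len)
    output = embeddings(input_ids)
    print(output.shape)  # torch.Size([2, 16, 768])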
