diff --git a/curated_transformers/layers/__init__.py b/curated_transformers/layers/__init__.py
index c52fa093..ef0d8be1 100644
--- a/curated_transformers/layers/__init__.py
+++ b/curated_transformers/layers/__init__.py
@@ -4,6 +4,7 @@
     AttentionHeads,
     AttentionLinearBiases,
     AttentionMask,
+    AttentionScorer,
     QkvMode,
     QkvSplit,
     ScaledDotProductAttention,
@@ -32,6 +33,7 @@
     "AttentionHeads",
     "AttentionLinearBiases",
     "AttentionMask",
+    "AttentionScorer",
     "CacheProtocol",
     "DecoderLayer",
     "EmbeddingDropouts",
diff --git a/curated_transformers/layers/attention.py b/curated_transformers/layers/attention.py
index 2fba63ad..cfe5aec4 100644
--- a/curated_transformers/layers/attention.py
+++ b/curated_transformers/layers/attention.py
@@ -9,7 +9,7 @@
 import torch
 import torch.nn.functional as F
 from torch import Tensor
-from torch.nn import Linear, Module
+from torch.nn import Dropout, Linear, Module
 
 from ..semver import Default, FutureMandatory
 from ..util.dataclass import DataclassAsDict
@@ -633,34 +633,12 @@ def forward(self, *, attention_scores: Tensor, inplace: bool = True) -> Tensor:
         return attention_scores + biases
 
 
-class ScaledDotProductAttention(Module):
+class AttentionScorer(Module, ABC):
     """
-    Scaled dot-product attention (`Vaswani et al., 2017`_).
-
-    .. _Vaswani et al., 2017: https://arxiv.org/abs/1706.03762
+    Base class of attention scoring implementations.
     """
 
-    linear_biases: Optional[AttentionLinearBiases]
-
-    def __init__(
-        self, *, dropout_prob: float, linear_biases: Optional[AttentionLinearBiases]
-    ):
-        """
-        Construct a scaled dot-product attention module.
-
-        :param dropout_prob:
-            Dropout to apply to the final hidden representation.
-        :param linear_biases:
-            ALiBi (`Press et al., 2022`_) for attention scores.
-            Not applied if ``None``.
-
-        .. _Press et al., 2022: https://arxiv.org/abs/2108.12409
-        """
-        super().__init__()
-
-        self.dropout = torch.nn.Dropout(p=dropout_prob)
-        self.linear_biases = linear_biases
-
+    @abstractmethod
     def forward(
         self,
         *,
@@ -670,9 +648,9 @@ def forward(
         attention_mask: AttentionMask,
     ) -> Tensor:
         """
-        Apply attention layer to the given key, query and value.
+        Apply attention scores to the given key, query and value.
 
-        Sequence elements that are marked with `False` in the attention mask
+        Sequence elements that are marked with ``False`` in the attention mask
         are ignored by the attention mechanism (if a mask is provided).
 
         :param query:
@@ -696,18 +674,78 @@
 
             *Shape:* ``(batch_size, heads, seq_len, width)``
         """
-        model_width = key.shape[-1]
-        attn_scores = query @ key.transpose(-2, -1)
-        attn_scores /= math.sqrt(model_width)
+        ...
+
+
+class ScaledDotProductAttention(AttentionScorer):
+    """
+    Scaled dot-product attention (`Vaswani et al., 2017`_).
+
+    .. _Vaswani et al., 2017: https://arxiv.org/abs/1706.03762
+    """
+
+    linear_biases: Optional[AttentionLinearBiases]
+
+    def __init__(
+        self, *, dropout_prob: float, linear_biases: Optional[AttentionLinearBiases]
+    ):
+        """
+        Construct a scaled dot-product attention module.
+
+        :param dropout_prob:
+            Dropout to apply to the final hidden representation.
+        :param linear_biases:
+            ALiBi (`Press et al., 2022`_) for attention scores.
+            Not applied if ``None``.
+
+        .. _Press et al., 2022: https://arxiv.org/abs/2108.12409
+        """
+        super().__init__()
 
-        if self.linear_biases is not None:
-            attn_scores = self.linear_biases(attention_scores=attn_scores)
+        self.dropout = Dropout(p=dropout_prob)
+        self.linear_biases = linear_biases
+
+    def forward(
+        self,
+        *,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        attention_mask: AttentionMask,
+    ) -> Tensor:
+        if _TORCH_SDP.get():
+            attn_mask = attention_mask.logit_mask(query.dtype)
+
+            # Add ALiBi to the logit mask
+            if self.linear_biases is not None:
+                biases = self.linear_biases.calculate_biases(key.size(-2)).to(
+                    dtype=query.dtype, device=query.device
+                )
+                bool_mask = attention_mask.bool_mask
+                attn_mask = torch.where(bool_mask, biases, attn_mask)
+
+            # We can't pass a bool mask, because it is currently broken:
+            # https://github.com/pytorch/pytorch/issues/103749
+            return F.scaled_dot_product_attention(
+                query=query,
+                key=key,
+                value=value,
+                attn_mask=attn_mask,
+                dropout_p=self.dropout.p if self.training else 0.0,
+            )
+        else:
+            width = key.shape[-1]
+            attn_scores = query @ key.transpose(-2, -1)
+            attn_scores /= math.sqrt(width)
+
+            if self.linear_biases is not None:
+                attn_scores = self.linear_biases(attention_scores=attn_scores)
 
-        attn_scores = attention_mask.apply_logit_mask(attn_scores)
-        attn_weights = attn_scores.softmax(dim=-1)
-        attn_values = self.dropout(attn_weights @ value)
+            attn_scores = attention_mask.apply_logit_mask(attn_scores)
+            attn_weights = attn_scores.softmax(dim=-1)
+            attn_values = self.dropout(attn_weights @ value)
 
-        return attn_values
+            return attn_values
 
 
 class SelfAttention(Module):
@@ -723,11 +761,10 @@ def __init__(
         self,
         *,
         attention_heads: AttentionHeads,
-        dropout_prob: float,
+        attention_scorer: AttentionScorer,
         hidden_width: int,
         qkv_mode: QkvMode,
         rotary_embeds: Optional[QueryKeyRotaryEmbeddings] = None,
-        attention_biases: Optional[AttentionLinearBiases] = None,
         use_bias: bool,
         device: Optional[torch.device] = None,
     ):
@@ -737,12 +774,10 @@
 
         :param attention_heads:
             Attention head configuration.
-        :param dropout_prob:
-            Dropout to apply between the self-attention and output layers.
+        :param attention_scorer:
+            Attention scorer used to calculate the attention values.
         :param hidden_width:
             Hidden width of the layer.
-        :param attention_biases:
-            ALiBi biases. ALiBi will not be used when set to ``None``.
         :param qkv_mode:
             Handling mode for query, key and value.
         :param rotary_embeds:
@@ -756,7 +791,6 @@
 
         super().__init__()
 
-        self.dropout_prob = dropout_prob
         self.attention_heads = attention_heads
         if hidden_width % attention_heads._n_query_heads != 0:
             raise ValueError(
@@ -766,13 +800,10 @@
 
         self.head_width = hidden_width // attention_heads._n_query_heads
         self.qkv_mode = qkv_mode
-        self.use_alibi = attention_biases is not None
         self.rotary_embeds = rotary_embeds
 
-        self.attention = ScaledDotProductAttention(
-            dropout_prob=dropout_prob, linear_biases=attention_biases
-        )
+        self.attention_scorer = attention_scorer
 
         if (
             qkv_mode == QkvMode.MERGED_SPLIT_BEFORE
@@ -877,34 +908,12 @@
             causal_mask = create_causal_mask(query, key)
             combined_mask = combined_mask.merge_mask(causal_mask)
 
-        if _TORCH_SDP.get():
-            attn_mask = combined_mask.logit_mask(query.dtype)
-
-            # Add AliBi to the logit mask
-            if self.use_alibi:
-                assert self.attention.linear_biases is not None
-                biases = self.attention.linear_biases.calculate_biases(key.size(-2)).to(
-                    dtype=query.dtype, device=query.device
-                )
-                bool_mask = combined_mask.bool_mask
-                attn_mask = torch.where(bool_mask, biases, attn_mask)
-
-            # We can't pass a bool mask, because it is currently broken:
-            # https://github.com/pytorch/pytorch/issues/103749
-            attn = F.scaled_dot_product_attention(
-                query=query,
-                key=key,
-                value=value,
-                attn_mask=attn_mask,
-                dropout_p=self.dropout_prob if self.training else 0.0,
-            )
-        else:
-            attn = self.attention(
-                query=query,
-                key=key,
-                value=value,
-                attention_mask=combined_mask,
-            )
+        attn = self.attention_scorer(
+            query=query,
+            key=key,
+            value=value,
+            attention_mask=combined_mask,
+        )
 
         attn = combine_heads(attn)
 
diff --git a/curated_transformers/models/albert/layer_group.py b/curated_transformers/models/albert/layer_group.py
index d0beae41..874e6c04 100644
--- a/curated_transformers/models/albert/layer_group.py
+++ b/curated_transformers/models/albert/layer_group.py
@@ -5,7 +5,13 @@
 from torch import Tensor
 from torch.nn import LayerNorm, Module, ModuleList
 
-from ...layers.attention import AttentionHeads, AttentionMask, QkvMode, SelfAttention
+from ...layers.attention import (
+    AttentionHeads,
+    AttentionMask,
+    QkvMode,
+    ScaledDotProductAttention,
+    SelfAttention,
+)
 from ...layers.feedforward import PointwiseFeedForward
 from ...layers.transformer import (
     EncoderLayer,
@@ -41,7 +47,10 @@
                         attention_heads=AttentionHeads.uniform(
                             attention_config.n_query_heads
                         ),
-                        dropout_prob=attention_config.dropout_prob,
+                        attention_scorer=ScaledDotProductAttention(
+                            dropout_prob=attention_config.dropout_prob,
+                            linear_biases=None,
+                        ),
                         hidden_width=layer_config.feedforward.hidden_width,
                         qkv_mode=QkvMode.SEPARATE,
                         rotary_embeds=None,
diff --git a/curated_transformers/models/bert/encoder.py b/curated_transformers/models/bert/encoder.py
index 363b2875..c9b4c6be 100644
--- a/curated_transformers/models/bert/encoder.py
+++ b/curated_transformers/models/bert/encoder.py
@@ -5,7 +5,12 @@
 from torch import Tensor
 from torch.nn import Dropout, LayerNorm
 
-from ...layers.attention import AttentionHeads, QkvMode, SelfAttention
+from ...layers.attention import (
+    AttentionHeads,
+    QkvMode,
+    ScaledDotProductAttention,
+    SelfAttention,
+)
 from ...layers.feedforward import PointwiseFeedForward
 from ...layers.transformer import (
     EmbeddingDropouts,
@@ -77,7 +82,10 @@ def __init__(self, config: BERTConfig, *, device: Optional[torch.device] = None)
                         attention_heads=AttentionHeads.uniform(
                             config.layer.attention.n_query_heads
                         ),
-                        dropout_prob=config.layer.attention.dropout_prob,
+                        attention_scorer=ScaledDotProductAttention(
+                            dropout_prob=config.layer.attention.dropout_prob,
+                            linear_biases=None,
+                        ),
                         hidden_width=config.layer.feedforward.hidden_width,
                         qkv_mode=QkvMode.SEPARATE,
                         rotary_embeds=None,
diff --git a/curated_transformers/models/falcon/decoder.py b/curated_transformers/models/falcon/decoder.py
index 478a70d0..5e08d928 100644
--- a/curated_transformers/models/falcon/decoder.py
+++ b/curated_transformers/models/falcon/decoder.py
@@ -3,12 +3,13 @@
 
 import torch
 from torch import Tensor
-from torch.nn import Dropout, Embedding, LayerNorm, ModuleList
+from torch.nn import Dropout, LayerNorm, ModuleList
 
 from ...layers.attention import (
     AttentionHeads,
     AttentionLinearBiases,
     QkvMode,
+    ScaledDotProductAttention,
     SelfAttention,
 )
 from ...layers.embeddings import QueryKeyRotaryEmbeddings
@@ -150,12 +151,14 @@
         )
         return DecoderLayer(
             attention_layer=SelfAttention(
-                attention_biases=attention_biases,
                 attention_heads=AttentionHeads.key_value_broadcast(
                     n_query_heads=n_attention_heads,
                     n_key_value_heads=config.layer.attention.n_key_value_heads,
                 ),
-                dropout_prob=config.layer.attention.dropout_prob,
+                attention_scorer=ScaledDotProductAttention(
+                    dropout_prob=config.layer.attention.dropout_prob,
+                    linear_biases=attention_biases,
+                ),
                 hidden_width=hidden_width,
                 qkv_mode=QkvMode.MERGED_SPLIT_AFTER,
                 rotary_embeds=rotary_embeds,
diff --git a/curated_transformers/models/falcon/layer.py b/curated_transformers/models/falcon/layer.py
index 003156dc..88d30135 100644
--- a/curated_transformers/models/falcon/layer.py
+++ b/curated_transformers/models/falcon/layer.py
@@ -10,6 +10,7 @@
     AttentionMask,
     KeyValueCache,
     QkvMode,
+    ScaledDotProductAttention,
     SelfAttention,
 )
 from ...layers.embeddings import QueryKeyRotaryEmbeddings
@@ -62,8 +63,10 @@
             else None
         )
         self.mha = SelfAttention(
-            attention_biases=attention_biases,
-            dropout_prob=attention_config.dropout_prob,
+            attention_scorer=ScaledDotProductAttention(
+                dropout_prob=attention_config.dropout_prob,
+                linear_biases=attention_biases,
+            ),
             hidden_width=hidden_width,
             attention_heads=AttentionHeads.key_value_broadcast(
                 n_query_heads=attention_config.n_query_heads,
diff --git a/curated_transformers/models/gpt_neox/decoder.py b/curated_transformers/models/gpt_neox/decoder.py
index c7eb634a..37f3e237 100644
--- a/curated_transformers/models/gpt_neox/decoder.py
+++ b/curated_transformers/models/gpt_neox/decoder.py
@@ -5,7 +5,12 @@
 from torch import Tensor
 from torch.nn import Dropout, LayerNorm, ModuleList
 
-from ...layers.attention import AttentionHeads, QkvMode, SelfAttention
+from ...layers.attention import (
+    AttentionHeads,
+    QkvMode,
+    ScaledDotProductAttention,
+    SelfAttention,
+)
 from ...layers.embeddings import QueryKeyRotaryEmbeddings
 from ...layers.feedforward import PointwiseFeedForward
 from ...layers.transformer import (
@@ -78,7 +83,10 @@
                 DecoderLayer(
                     attention_layer=SelfAttention(
                         attention_heads=AttentionHeads.uniform(n_attention_heads),
-                        dropout_prob=config.layer.attention.dropout_prob,
+                        attention_scorer=ScaledDotProductAttention(
+                            dropout_prob=config.layer.attention.dropout_prob,
+                            linear_biases=None,
+                        ),
                         hidden_width=hidden_width,
                         qkv_mode=QkvMode.MERGED_SPLIT_BEFORE,
                         rotary_embeds=QueryKeyRotaryEmbeddings(
diff --git a/curated_transformers/models/llama/decoder.py b/curated_transformers/models/llama/decoder.py
index 1e4b7e69..0fb58644 100644
--- a/curated_transformers/models/llama/decoder.py
+++ b/curated_transformers/models/llama/decoder.py
@@ -5,7 +5,12 @@
 from torch import Tensor
 from torch.nn import Dropout, ModuleList
 
-from ...layers.attention import AttentionHeads, QkvMode, SelfAttention
+from ...layers.attention import (
+    AttentionHeads,
+    QkvMode,
+    ScaledDotProductAttention,
+    SelfAttention,
+)
 from ...layers.embeddings import QueryKeyRotaryEmbeddings
 from ...layers.feedforward import PointwiseFeedForward
 from ...layers.normalization import RMSNorm
@@ -84,7 +89,10 @@
                 DecoderLayer(
                     attention_layer=SelfAttention(
                         attention_heads=attention_heads,
-                        dropout_prob=config.layer.attention.dropout_prob,
+                        attention_scorer=ScaledDotProductAttention(
+                            dropout_prob=config.layer.attention.dropout_prob,
+                            linear_biases=None,
+                        ),
                         hidden_width=hidden_width,
                         qkv_mode=QkvMode.SEPARATE,
                         rotary_embeds=QueryKeyRotaryEmbeddings(
diff --git a/curated_transformers/models/mpt/decoder.py b/curated_transformers/models/mpt/decoder.py
index f6acc05d..5ee41909 100644
--- a/curated_transformers/models/mpt/decoder.py
+++ b/curated_transformers/models/mpt/decoder.py
@@ -9,6 +9,7 @@
     AttentionLinearBiases,
     QkvMode,
     QkvSplitGroupedByHead,
+    ScaledDotProductAttention,
     SelfAttention,
 )
 from ...layers.feedforward import PointwiseFeedForward
@@ -86,11 +87,13 @@ def layer_norm():
             [
                 DecoderLayer(
                     attention_layer=SelfAttention(
-                        attention_biases=attention_biases,
                         attention_heads=AttentionHeads.uniform(
                             n_attention_heads, QkvSplitGroupedByHead()
                         ),
-                        dropout_prob=config.layer.attention.dropout_prob,
+                        attention_scorer=ScaledDotProductAttention(
+                            dropout_prob=config.layer.attention.dropout_prob,
+                            linear_biases=attention_biases,
+                        ),
                         hidden_width=hidden_width,
                         qkv_mode=QkvMode.MERGED_SPLIT_AFTER,
                         rotary_embeds=None,
diff --git a/curated_transformers/models/roberta/encoder.py b/curated_transformers/models/roberta/encoder.py
index b9cfc3f3..88d4317c 100644
--- a/curated_transformers/models/roberta/encoder.py
+++ b/curated_transformers/models/roberta/encoder.py
@@ -5,7 +5,12 @@
 from torch import Tensor
 from torch.nn import Dropout, LayerNorm
 
-from ...layers.attention import AttentionHeads, QkvMode, SelfAttention
+from ...layers.attention import (
+    AttentionHeads,
+    QkvMode,
+    ScaledDotProductAttention,
+    SelfAttention,
+)
 from ...layers.feedforward import PointwiseFeedForward
 from ...layers.transformer import (
     EmbeddingDropouts,
@@ -77,7 +82,10 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No
                         attention_heads=AttentionHeads.uniform(
                             config.layer.attention.n_query_heads
                         ),
-                        dropout_prob=config.layer.attention.dropout_prob,
+                        attention_scorer=ScaledDotProductAttention(
+                            dropout_prob=config.layer.attention.dropout_prob,
+                            linear_biases=None,
+                        ),
                         hidden_width=hidden_width,
                         qkv_mode=QkvMode.SEPARATE,
                         rotary_embeds=None,
diff --git a/docs/source/api-compat.rst b/docs/source/api-compat.rst
index 0b753ee8..65580524 100644
--- a/docs/source/api-compat.rst
+++ b/docs/source/api-compat.rst
@@ -112,3 +112,4 @@ Version 1 to 2
 * The ``FromHFHub`` mixins will be renamed to ``FromHF``.
 * The ``convert_hf_state_dict`` method in ``FromHFHub`` will be removed in
   favour of ``state_dict_from_hf``.
+* The ``SelfAttention`` class will take an additional ``AttentionScorer`` argument.
\ No newline at end of file
diff --git a/docs/source/building-blocks.rst b/docs/source/building-blocks.rst
index 566f1797..b854ca6d 100644
--- a/docs/source/building-blocks.rst
+++ b/docs/source/building-blocks.rst
@@ -55,6 +55,9 @@ These modules and their helper classes implement the Transformer attention mecha
 .. autoclass:: curated_transformers.layers.KeyValueCache
    :members:
 
+.. autoclass:: curated_transformers.layers.AttentionScorer
+   :members:
+
 .. autoclass:: curated_transformers.layers.AttentionLinearBiases
    :members:
    :show-inheritance:
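
A minimal usage sketch of the refactored API, mirroring the wiring in the model code above. The import path, class names, and constructor keywords are taken from this diff; the head count, hidden width, dropout probability, and the custom UnscaledDotProductAttention scorer are illustrative assumptions, not part of the change.

from torch import Tensor

from curated_transformers.layers.attention import (
    AttentionHeads,
    AttentionMask,
    AttentionScorer,
    QkvMode,
    ScaledDotProductAttention,
    SelfAttention,
)

# Wire up SelfAttention with the stock scorer, as the BERT/RoBERTa encoders now do.
# The concrete values (12 heads, width 768, dropout 0.1) are placeholders.
attention = SelfAttention(
    attention_heads=AttentionHeads.uniform(12),
    attention_scorer=ScaledDotProductAttention(
        dropout_prob=0.1,
        linear_biases=None,
    ),
    hidden_width=768,
    qkv_mode=QkvMode.SEPARATE,
    rotary_embeds=None,
    use_bias=True,
)


class UnscaledDotProductAttention(AttentionScorer):
    """
    Hypothetical custom scorer: dot-product attention without the
    1/sqrt(width) scaling, illustrating the extension point that
    AttentionScorer provides.
    """

    def forward(
        self,
        *,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attention_mask: AttentionMask,
    ) -> Tensor:
        # Same masking and softmax steps as the non-SDP branch above, minus scaling.
        attn_scores = query @ key.transpose(-2, -1)
        attn_scores = attention_mask.apply_logit_mask(attn_scores)
        return attn_scores.softmax(dim=-1) @ value

Moving the Torch SDP dispatch into ScaledDotProductAttention keeps SelfAttention agnostic to how scores are computed, so an alternative scorer like the one sketched here can be swapped in without touching the rest of the attention plumbing.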