Add AttentionScorer abstraction (#349)
* Add `AttentionScorer` abstraction

* Remove Torch SDP warning
shadeMe committed Oct 5, 2023
1 parent 84a37f0 commit 499e9d9
Showing 12 changed files with 158 additions and 93 deletions.
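The net effect at call sites: `SelfAttention` no longer takes `dropout_prob` and `attention_biases` directly, but instead receives an `AttentionScorer`, with `ScaledDotProductAttention` now carrying those settings. As a rough sketch of a call site after this change (the head count, hidden width, and dropout probability below are illustrative, not taken from the commit):

from curated_transformers.layers.attention import (
    AttentionHeads,
    QkvMode,
    ScaledDotProductAttention,
    SelfAttention,
)

# Dropout and optional ALiBi biases are configured on the scorer, which is
# passed to SelfAttention in place of the old dropout_prob/attention_biases.
attention = SelfAttention(
    attention_heads=AttentionHeads.uniform(12),
    attention_scorer=ScaledDotProductAttention(
        dropout_prob=0.1,
        linear_biases=None,
    ),
    hidden_width=768,
    qkv_mode=QkvMode.SEPARATE,
    rotary_embeds=None,
    use_bias=True,
)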
2 changes: 2 additions & 0 deletions curated_transformers/layers/__init__.py
@@ -4,6 +4,7 @@
AttentionHeads,
AttentionLinearBiases,
AttentionMask,
AttentionScorer,
QkvMode,
QkvSplit,
ScaledDotProductAttention,
@@ -32,6 +33,7 @@
"AttentionHeads",
"AttentionLinearBiases",
"AttentionMask",
"AttentionScorer",
"CacheProtocol",
"DecoderLayer",
"EmbeddingDropouts",
161 changes: 85 additions & 76 deletions curated_transformers/layers/attention.py
@@ -9,7 +9,7 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn import Dropout, Linear, Module

from ..semver import Default, FutureMandatory
from ..util.dataclass import DataclassAsDict
@@ -633,34 +633,12 @@ def forward(self, *, attention_scores: Tensor, inplace: bool = True) -> Tensor:
return attention_scores + biases


class ScaledDotProductAttention(Module):
class AttentionScorer(Module, ABC):
"""
Scaled dot-product attention (`Vaswani et al., 2017`_).
.. _Vaswani et al., 2017: https://arxiv.org/abs/1706.03762
Base class of attention scoring implementations.
"""

linear_biases: Optional[AttentionLinearBiases]

def __init__(
self, *, dropout_prob: float, linear_biases: Optional[AttentionLinearBiases]
):
"""
Construct a scaled dot-product attention module.
:param dropout_prob:
Dropout to apply to the final hidden representation.
:param linear_biases:
ALiBi (`Press et al., 2022`_) for attention scores.
Not applied if ``None``.
.. _Press et al., 2022: https://arxiv.org/abs/2108.12409
"""
super().__init__()

self.dropout = torch.nn.Dropout(p=dropout_prob)
self.linear_biases = linear_biases

@abstractmethod
def forward(
self,
*,
@@ -670,9 +648,9 @@ def forward(
attention_mask: AttentionMask,
) -> Tensor:
"""
Apply attention layer to the given key, query and value.
Apply attention scores to the given key, query and value.
Sequence elements that are marked with `False` in the attention mask
Sequence elements that are marked with ``False`` in the attention mask
are ignored by the attention mechanism (if a mask is provided).
:param query:
@@ -696,18 +674,78 @@ def forward(
*Shape:* ``(batch_size, heads, seq_len, width)``
"""
model_width = key.shape[-1]
attn_scores = query @ key.transpose(-2, -1)
attn_scores /= math.sqrt(model_width)
...


class ScaledDotProductAttention(AttentionScorer):
"""
Scaled dot-product attention (`Vaswani et al., 2017`_).
.. _Vaswani et al., 2017: https://arxiv.org/abs/1706.03762
"""

linear_biases: Optional[AttentionLinearBiases]

def __init__(
self, *, dropout_prob: float, linear_biases: Optional[AttentionLinearBiases]
):
"""
Construct a scaled dot-product attention module.
:param dropout_prob:
Dropout to apply to the final hidden representation.
:param linear_biases:
ALiBi (`Press et al., 2022`_) for attention scores.
Not applied if ``None``.
.. _Press et al., 2022: https://arxiv.org/abs/2108.12409
"""
super().__init__()

if self.linear_biases is not None:
attn_scores = self.linear_biases(attention_scores=attn_scores)
self.dropout = Dropout(p=dropout_prob)
self.linear_biases = linear_biases

def forward(
self,
*,
query: Tensor,
key: Tensor,
value: Tensor,
attention_mask: AttentionMask,
) -> Tensor:
if _TORCH_SDP.get():
attn_mask = attention_mask.logit_mask(query.dtype)

# Add AliBi to the logit mask
if self.linear_biases is not None:
biases = self.linear_biases.calculate_biases(key.size(-2)).to(
dtype=query.dtype, device=query.device
)
bool_mask = attention_mask.bool_mask
attn_mask = torch.where(bool_mask, biases, attn_mask)

# We can't pass a bool mask, because it is currently broken:
# https://github.com/pytorch/pytorch/issues/103749
return F.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=self.dropout.p if self.training else 0.0,
)
else:
width = key.shape[-1]
attn_scores = query @ key.transpose(-2, -1)
attn_scores /= math.sqrt(width)

if self.linear_biases is not None:
attn_scores = self.linear_biases(attention_scores=attn_scores)

attn_scores = attention_mask.apply_logit_mask(attn_scores)
attn_weights = attn_scores.softmax(dim=-1)
attn_values = self.dropout(attn_weights @ value)
attn_scores = attention_mask.apply_logit_mask(attn_scores)
attn_weights = attn_scores.softmax(dim=-1)
attn_values = self.dropout(attn_weights @ value)

return attn_values
return attn_values


class SelfAttention(Module):
@@ -723,11 +761,10 @@ def __init__(
self,
*,
attention_heads: AttentionHeads,
dropout_prob: float,
attention_scorer: AttentionScorer,
hidden_width: int,
qkv_mode: QkvMode,
rotary_embeds: Optional[QueryKeyRotaryEmbeddings] = None,
attention_biases: Optional[AttentionLinearBiases] = None,
use_bias: bool,
device: Optional[torch.device] = None,
):
@@ -737,12 +774,10 @@ def __init__(
:param attention_heads:
Attention head configuration.
:param dropout_prob:
Dropout to apply between the self-attention and output layers.
:param attention_scorer:
Attention scorer used to calculate the attention values.
:param hidden_width:
Hidden width of the layer.
:param attention_biases:
ALiBi biases. ALiBi will not be used when set to ``None``.
:param qkv_mode:
Handling mode for query, key and value.
:param rotary_embeds:
@@ -756,7 +791,6 @@

super().__init__()

self.dropout_prob = dropout_prob
self.attention_heads = attention_heads
if hidden_width % attention_heads._n_query_heads != 0:
raise ValueError(
@@ -766,13 +800,10 @@

self.head_width = hidden_width // attention_heads._n_query_heads
self.qkv_mode = qkv_mode
self.use_alibi = attention_biases is not None

self.rotary_embeds = rotary_embeds

self.attention = ScaledDotProductAttention(
dropout_prob=dropout_prob, linear_biases=attention_biases
)
self.attention_scorer = attention_scorer

if (
qkv_mode == QkvMode.MERGED_SPLIT_BEFORE
@@ -877,34 +908,12 @@ def forward(
causal_mask = create_causal_mask(query, key)
combined_mask = combined_mask.merge_mask(causal_mask)

if _TORCH_SDP.get():
attn_mask = combined_mask.logit_mask(query.dtype)

# Add AliBi to the logit mask
if self.use_alibi:
assert self.attention.linear_biases is not None
biases = self.attention.linear_biases.calculate_biases(key.size(-2)).to(
dtype=query.dtype, device=query.device
)
bool_mask = combined_mask.bool_mask
attn_mask = torch.where(bool_mask, biases, attn_mask)

# We can't pass a bool mask, because it is currently broken:
# https://github.com/pytorch/pytorch/issues/103749
attn = F.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=self.dropout_prob if self.training else 0.0,
)
else:
attn = self.attention(
query=query,
key=key,
value=value,
attention_mask=combined_mask,
)
attn = self.attention_scorer(
query=query,
key=key,
value=value,
attention_mask=combined_mask,
)

attn = combine_heads(attn)

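Since attention scoring is now an explicit abstraction, alternative scoring strategies can be plugged into `SelfAttention` without modifying it. A minimal sketch of a custom scorer (`UnscaledDotProductAttention` is a hypothetical example, not part of this commit):

from torch import Tensor

from curated_transformers.layers.attention import AttentionMask, AttentionScorer


class UnscaledDotProductAttention(AttentionScorer):
    """Dot-product attention without the 1/sqrt(width) scaling factor."""

    def forward(
        self,
        *,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attention_mask: AttentionMask,
    ) -> Tensor:
        # Raw dot-product scores between queries and keys.
        attn_scores = query @ key.transpose(-2, -1)
        # Ignore sequence elements marked False in the attention mask.
        attn_scores = attention_mask.apply_logit_mask(attn_scores)
        # Normalize the scores and weight the value vectors.
        return attn_scores.softmax(dim=-1) @ value

An instance of such a scorer would then be passed to `SelfAttention` through the new `attention_scorer` argument, exactly as `ScaledDotProductAttention` is wired into the model code in the diffs that follow.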
13 changes: 11 additions & 2 deletions curated_transformers/models/albert/layer_group.py
@@ -5,7 +5,13 @@
from torch import Tensor
from torch.nn import LayerNorm, Module, ModuleList

from ...layers.attention import AttentionHeads, AttentionMask, QkvMode, SelfAttention
from ...layers.attention import (
AttentionHeads,
AttentionMask,
QkvMode,
ScaledDotProductAttention,
SelfAttention,
)
from ...layers.feedforward import PointwiseFeedForward
from ...layers.transformer import (
EncoderLayer,
@@ -41,7 +47,10 @@ def __init__(
attention_heads=AttentionHeads.uniform(
attention_config.n_query_heads
),
dropout_prob=attention_config.dropout_prob,
attention_scorer=ScaledDotProductAttention(
dropout_prob=attention_config.dropout_prob,
linear_biases=None,
),
hidden_width=layer_config.feedforward.hidden_width,
qkv_mode=QkvMode.SEPARATE,
rotary_embeds=None,
12 changes: 10 additions & 2 deletions curated_transformers/models/bert/encoder.py
@@ -5,7 +5,12 @@
from torch import Tensor
from torch.nn import Dropout, LayerNorm

from ...layers.attention import AttentionHeads, QkvMode, SelfAttention
from ...layers.attention import (
AttentionHeads,
QkvMode,
ScaledDotProductAttention,
SelfAttention,
)
from ...layers.feedforward import PointwiseFeedForward
from ...layers.transformer import (
EmbeddingDropouts,
@@ -77,7 +82,10 @@ def __init__(self, config: BERTConfig, *, device: Optional[torch.device] = None)
attention_heads=AttentionHeads.uniform(
config.layer.attention.n_query_heads
),
dropout_prob=config.layer.attention.dropout_prob,
attention_scorer=ScaledDotProductAttention(
dropout_prob=config.layer.attention.dropout_prob,
linear_biases=None,
),
hidden_width=config.layer.feedforward.hidden_width,
qkv_mode=QkvMode.SEPARATE,
rotary_embeds=None,
9 changes: 6 additions & 3 deletions curated_transformers/models/falcon/decoder.py
@@ -3,12 +3,13 @@

import torch
from torch import Tensor
from torch.nn import Dropout, Embedding, LayerNorm, ModuleList
from torch.nn import Dropout, LayerNorm, ModuleList

from ...layers.attention import (
AttentionHeads,
AttentionLinearBiases,
QkvMode,
ScaledDotProductAttention,
SelfAttention,
)
from ...layers.embeddings import QueryKeyRotaryEmbeddings
@@ -150,12 +151,14 @@ def _create_new_decoder_architecture_layer(
)
return DecoderLayer(
attention_layer=SelfAttention(
attention_biases=attention_biases,
attention_heads=AttentionHeads.key_value_broadcast(
n_query_heads=n_attention_heads,
n_key_value_heads=config.layer.attention.n_key_value_heads,
),
dropout_prob=config.layer.attention.dropout_prob,
attention_scorer=ScaledDotProductAttention(
dropout_prob=config.layer.attention.dropout_prob,
linear_biases=attention_biases,
),
hidden_width=hidden_width,
qkv_mode=QkvMode.MERGED_SPLIT_AFTER,
rotary_embeds=rotary_embeds,
7 changes: 5 additions & 2 deletions curated_transformers/models/falcon/layer.py
@@ -10,6 +10,7 @@
AttentionMask,
KeyValueCache,
QkvMode,
ScaledDotProductAttention,
SelfAttention,
)
from ...layers.embeddings import QueryKeyRotaryEmbeddings
@@ -62,8 +63,10 @@ def __init__(
else None
)
self.mha = SelfAttention(
attention_biases=attention_biases,
dropout_prob=attention_config.dropout_prob,
attention_scorer=ScaledDotProductAttention(
dropout_prob=attention_config.dropout_prob,
linear_biases=attention_biases,
),
hidden_width=hidden_width,
attention_heads=AttentionHeads.key_value_broadcast(
n_query_heads=attention_config.n_query_heads,
12 changes: 10 additions & 2 deletions curated_transformers/models/gpt_neox/decoder.py
@@ -5,7 +5,12 @@
from torch import Tensor
from torch.nn import Dropout, LayerNorm, ModuleList

from ...layers.attention import AttentionHeads, QkvMode, SelfAttention
from ...layers.attention import (
AttentionHeads,
QkvMode,
ScaledDotProductAttention,
SelfAttention,
)
from ...layers.embeddings import QueryKeyRotaryEmbeddings
from ...layers.feedforward import PointwiseFeedForward
from ...layers.transformer import (
@@ -78,7 +83,10 @@ def __init__(
DecoderLayer(
attention_layer=SelfAttention(
attention_heads=AttentionHeads.uniform(n_attention_heads),
dropout_prob=config.layer.attention.dropout_prob,
attention_scorer=ScaledDotProductAttention(
dropout_prob=config.layer.attention.dropout_prob,
linear_biases=None,
),
hidden_width=hidden_width,
qkv_mode=QkvMode.MERGED_SPLIT_BEFORE,
rotary_embeds=QueryKeyRotaryEmbeddings(