Clear output of Torch SDPA for masked pieces (#360)
* Clear output of Torch SDPA for masked pieces

Since Torch 2.1, the Torch memory-efficient SDPA GPU kernel returns NaN
for pieces that are completely masked out. This leads to NaN propagation
in the next attention layer, because masked pieces get an attention of
zero, but zero times NaN is still NaN.
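
A toy, hypothetical repro of this failure mode (shapes made up, not the
library's code; whether NaN actually appears depends on which SDPA kernel
Torch selects):

import torch
import torch.nn.functional as F

# Toy shapes; the last piece of the sequence is padding.
batch, n_heads, seq_len, head_dim = 1, 2, 4, 8
q = torch.randn(batch, n_heads, seq_len, head_dim)
k = torch.randn(batch, n_heads, seq_len, head_dim)
v = torch.randn(batch, n_heads, seq_len, head_dim)

# Additive logit mask: nobody attends to the padding piece, and the padding
# piece attends to nothing, so its logit row is entirely -inf.
logit_mask = torch.zeros(batch, 1, seq_len, seq_len)
logit_mask[..., :, 3] = float("-inf")
logit_mask[..., 3, :] = float("-inf")

out = F.scaled_dot_product_attention(q, k, v, attn_mask=logit_mask)
print(out[..., 3, :].isnan().any())  # True on kernels that return NaN here

# Zero attention times NaN is still NaN, so the NaN leaks into the next layer:
weights = torch.zeros(seq_len)  # the padding piece gets zero attention weight
print(weights @ out[0, 0])      # 0 * NaN is NaN -> all-NaN output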

We fix this by setting the output of masked pieces to zero to clear out
any NaNs.

We currently require the query dimension of the mask to be singular, but
in the future we should probably redesign the `AttentionMask` class to
account for the differences between attention masks and causal masks.

* black

* Update MyPy version to one that supports recent PyTorch

* Comment typos and fixes

* Add assertion message

Co-authored-by: Madeesh Kannan <[email protected]>

* black

---------

Co-authored-by: Madeesh Kannan <[email protected]>
danieldk and shadeMe committed Feb 8, 2024
1 parent 23f3a1b commit f9da3b5
Showing 7 changed files with 61 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -14,7 +14,7 @@ jobs:
- name: Configure Python version
uses: actions/setup-python@v4
with:
python-version: "3.8"
python-version: "3.11"
architecture: x64

- name: black
57 changes: 45 additions & 12 deletions curated_transformers/layers/attention.py
@@ -646,6 +646,7 @@ def forward(
key: Tensor,
value: Tensor,
attention_mask: AttentionMask,
use_causal_mask: bool,
) -> Tensor:
"""
Apply attention scores to the given key, query and value.
@@ -669,6 +670,8 @@
Attention mask. Sequence elements for which the corresponding mask
element is set to ``False`` are ignored in attention.
:param use_causal_mask:
Mask out succeeding sequence elements when ``True``.
:returns:
Attention values.
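
As a raw-tensor aside on the `use_causal_mask` parameter documented above
(toy shapes, not the library's `AttentionMask` or `create_causal_mask` API):
merging a boolean padding mask with a causal mask amounts to a logical AND
of the two.

import torch

# True means "may attend". Padding mask: [batch, 1, 1, key_len];
# causal mask: [1, 1, query_len, key_len].
batch, query_len, key_len = 2, 4, 4
padding_mask = torch.ones(batch, 1, 1, key_len, dtype=torch.bool)
padding_mask[1, ..., 3:] = False  # last piece of the second sequence is padding

causal_mask = torch.tril(torch.ones(query_len, key_len, dtype=torch.bool))
causal_mask = causal_mask.view(1, 1, query_len, key_len)

# Broadcasting gives the combined mask: [batch, 1, query_len, key_len].
combined_mask = padding_mask & causal_mask
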
@@ -712,27 +715,61 @@ def forward(
key: Tensor,
value: Tensor,
attention_mask: AttentionMask,
use_causal_mask: bool,
) -> Tensor:
combined_mask = attention_mask
if use_causal_mask:
causal_mask = create_causal_mask(query, key)
combined_mask = combined_mask.merge_mask(causal_mask)

if _TORCH_SDP.get():
attn_mask = attention_mask.logit_mask(query.dtype)
logit_mask = combined_mask.logit_mask(query.dtype)

# Add AliBi to the logit mask
if self.linear_biases is not None:
biases = self.linear_biases.calculate_biases(key.size(-2)).to(
dtype=query.dtype, device=query.device
)
bool_mask = attention_mask.bool_mask
attn_mask = torch.where(bool_mask, biases, attn_mask)
bool_mask = combined_mask.bool_mask
logit_mask = torch.where(bool_mask, biases, logit_mask)

# We can't pass a bool mask, because it is currently broken:
# https://github.com/pytorch/pytorch/issues/103749
return F.scaled_dot_product_attention(
attn_values = F.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
attn_mask=logit_mask,
dropout_p=self.dropout_prob if self.training else 0.0,
)

# Torch SDP returns NaNs for pieces that are completely masked out.
# These NaNs propagate because zero attention times NaN is NaN.
# Since the representations of these pieces don't matter anyway, we
# will just zero them out.
#
# One issue is that values have shape
#
# [batch_len, n_heads, key_len, hidden_size]
#
# whereas masks have the shape
#
# [batch_len, 1, query_len, key_len]
#
# So we can only do this when we have attention masks where
# the query length is not specified, which are typically 'pure'
# attention masks (not causal masks or combined masks):
#
# [batch_len, 1, 1, key_len]
#
# Doing this properly requires a redesign of our AttentionMask
# class.
assert (
attention_mask.bool_mask.size(-2) == 1
), "Torch SDP does not support attention masks with non-broadcastable query length yet"
return torch.where(
attention_mask.bool_mask.transpose(-1, -2), attn_values, 0.0
)
else:
width = key.shape[-1]
attn_scores = query @ key.transpose(-2, -1)
Expand All @@ -741,7 +778,7 @@ def forward(
if self.linear_biases is not None:
attn_scores = self.linear_biases(attention_scores=attn_scores)

attn_scores = attention_mask.apply_logit_mask(attn_scores)
attn_scores = combined_mask.apply_logit_mask(attn_scores)
attn_weights = attn_scores.softmax(dim=-1)
attn_values = self.dropout(attn_weights @ value)
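
For the Torch SDP branch above, a toy sketch (not the library's
`AttentionMask` API) of how a boolean mask becomes an additive logit mask
and how hypothetical ALiBi-style biases are spliced into it with
`torch.where`:

import torch

batch, n_heads, query_len, key_len = 2, 4, 1, 6
bool_mask = torch.ones(batch, 1, query_len, key_len, dtype=torch.bool)
bool_mask[1, ..., 4:] = False  # last two keys of the second sequence are masked

# Boolean mask -> additive logit mask: True -> 0.0, False -> -inf.
logit_mask = torch.zeros(batch, 1, query_len, key_len).masked_fill(
    ~bool_mask, float("-inf")
)

# Hypothetical per-head ALiBi-style biases over key positions (illustrative
# slopes, not the exact ALiBi slope schedule).
slopes = torch.tensor([2.0 ** -(i + 1) for i in range(n_heads)])
distances = torch.arange(key_len - 1, -1, -1, dtype=torch.float)
biases = -slopes.view(1, n_heads, 1, 1) * distances.view(1, 1, 1, key_len)

# Attendable positions carry the bias, masked positions stay at -inf.
logit_mask = torch.where(bool_mask, biases, logit_mask)  # [batch, n_heads, 1, key_len]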

@@ -903,16 +940,12 @@ def forward(
key = torch.cat([cache_k, key], dim=-2)
value = torch.cat([cache_v, value], dim=-2)

combined_mask = attention_mask
if use_causal_mask:
causal_mask = create_causal_mask(query, key)
combined_mask = combined_mask.merge_mask(causal_mask)

attn = self.attention_scorer(
query=query,
key=key,
value=value,
attention_mask=combined_mask,
attention_mask=attention_mask,
use_causal_mask=use_causal_mask,
)

attn = combine_heads(attn)
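
The zero-out added in the hunk above hinges on broadcasting a transposed
boolean mask against the attention output. A minimal sketch with toy shapes
(not the library's `AttentionMask` class), assuming a 'pure' attention mask
of shape [batch_len, 1, 1, key_len] and an output whose second-to-last
dimension equals key_len:

import torch

batch, n_heads, key_len, head_dim = 2, 4, 6, 8
attn_values = torch.randn(batch, n_heads, key_len, head_dim)
attn_values[1, :, 4:, :] = float("nan")  # pretend SDPA returned NaNs for masked pieces

bool_mask = torch.ones(batch, 1, 1, key_len, dtype=torch.bool)
bool_mask[1, ..., 4:] = False  # last two pieces of the second sequence are masked

# [batch, 1, 1, key_len] -> transpose -> [batch, 1, key_len, 1], which
# broadcasts over the head and hidden dimensions of attn_values.
cleared = torch.where(bool_mask.transpose(-1, -2), attn_values, 0.0)
assert not cleared.isnan().any()
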
8 changes: 5 additions & 3 deletions curated_transformers/models/falcon/layer.py
@@ -73,9 +73,11 @@ def __init__(
n_key_value_heads=attention_config.n_key_value_heads,
),
rotary_embeds=rotary_embeds,
qkv_mode=QkvMode.MERGED_SPLIT_AFTER
if attention_config.n_key_value_heads == 1
else QkvMode.MERGED_SPLIT_BEFORE,
qkv_mode=(
QkvMode.MERGED_SPLIT_AFTER
if attention_config.n_key_value_heads == 1
else QkvMode.MERGED_SPLIT_BEFORE
),
use_bias=attention_config.use_bias,
device=device,
)
11 changes: 6 additions & 5 deletions curated_transformers/repository/hf_hub.py
@@ -1,7 +1,7 @@
import os
import warnings
from tempfile import NamedTemporaryFile
from typing import IO, Any, AnyStr, Dict, Iterator, List, Optional
from typing import IO, Any, AnyStr, Dict, Iterator, List, Optional, Union

import huggingface_hub
from huggingface_hub import CommitOperationAdd, HfApi
@@ -11,6 +11,7 @@
RevisionNotFoundError,
)
from requests import HTTPError, ReadTimeout # type: ignore
from typing_extensions import Buffer

from ..repository.file import LocalFile, RepositoryFile
from .repository import Repository
@@ -137,9 +138,9 @@ def __init__(self, repo: HfHubRepository):
"""
super().__init__()
self._repo = repo
self._file_mappings: Dict[
str, IO
] = {} # Maps remote file paths to local temporary files
self._file_mappings: Dict[str, IO] = (
{}
) # Maps remote file paths to local temporary files

def open(self, path: str, mode: str, encoding: Optional[str] = None) -> IO:
if path in self._file_mappings:
@@ -297,7 +298,7 @@ def truncate(self, size: int = None) -> int: # type:ignore
def writable(self) -> bool:
return self._temp_file.writable()

def write(self, s: AnyStr) -> int:
def write(self, s: Union[Any, Buffer, str]) -> int:
return self._temp_file.write(s)

def writelines(self, lines: List[AnyStr]) -> None: # type:ignore
4 changes: 1 addition & 3 deletions curated_transformers/tests/models/util.py
@@ -60,9 +60,7 @@ class JITMethod(Enum):
TorchCompile = 1
TorchScriptTrace = 2

def convert(
self, model: Module, with_torch_sdp: bool, *args
) -> Tuple[
def convert(self, model: Module, with_torch_sdp: bool, *args) -> Tuple[
Union[Module, torch.ScriptModule],
Callable[[Union[ModelOutput, Dict[str, torch.Tensor]]], Tensor],
]:
6 changes: 2 additions & 4 deletions curated_transformers/tokenizers/legacy/legacy_tokenizer.py
@@ -155,12 +155,10 @@ def _convert_strings(
@abstractmethod
def _decode(
self, input: Iterable[Iterable[int]], skip_special_pieces: bool
) -> List[str]:
...
) -> List[str]: ...

@abstractmethod
def _encode(self, input: Iterable[MergedInputChunks]) -> PiecesWithIds:
...
def _encode(self, input: Iterable[MergedInputChunks]) -> PiecesWithIds: ...


class AddBosEosPreEncoder(PreEncoder):
2 changes: 1 addition & 1 deletion requirements.txt
@@ -5,5 +5,5 @@ tokenizers>=0.13.3
torch>=1.12.0

# Development dependencies
mypy>=0.990,<1.1.0; platform_machine != "aarch64"
mypy>=1.5.0,<1.6.0; platform_machine != "aarch64"
pytest
