diff --git a/curated_transformers/layers/attention.py b/curated_transformers/layers/attention.py
index b32561c7..fedec5b4 100644
--- a/curated_transformers/layers/attention.py
+++ b/curated_transformers/layers/attention.py
@@ -764,7 +764,7 @@ def forward(
             #
             # Doing this properly requires a redesign of our AttentionMask
             # class.
-            assert attention_mask.bool_mask.size(-2) == 1
+            assert attention_mask.bool_mask.size(-2) == 1, "Torch SDP does not support attention masks with non-broadcastable query length yet"
             return torch.where(
                 attention_mask.bool_mask.transpose(-1, -2), attn_values, 0.0
             )
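
For context (outside the patch), here is a minimal sketch of the broadcast this assert guards. The shapes are assumptions inferred from the `transpose(-1, -2)` and `torch.where` call in this hunk (mask as `[batch, 1, query_len, key_len]`, attention values as `[batch, n_heads, key_len, head_width]`), not taken from the library itself:

```python
# Sketch only: the shapes below are assumptions inferred from the hunk,
# not the library's documented layout.
import torch

batch, n_heads, key_len, head_width = 2, 4, 8, 16

# Attention output as produced by torch SDP in this code path.
attn_values = torch.randn(batch, n_heads, key_len, head_width)

# Boolean mask with a broadcastable (size-1) query dimension:
# [batch, 1, query_len=1, key_len].
bool_mask = torch.rand(batch, 1, 1, key_len) > 0.1

# transpose(-1, -2) yields [batch, 1, key_len, 1], which broadcasts
# against attn_values and zeroes out the fully masked key positions.
assert bool_mask.size(-2) == 1
zeroed = torch.where(bool_mask.transpose(-1, -2), attn_values, 0.0)
print(zeroed.shape)  # torch.Size([2, 4, 8, 16])

# With query_len > 1, the transposed mask would be
# [batch, 1, key_len, query_len], which no longer lines up with
# head_width, and torch.where would raise a broadcast error. That
# failure mode is what the new assert message surfaces early.
```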