From eefe9000da352cee640943eace8926799bb86cb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 8 Feb 2024 20:02:14 +0100 Subject: [PATCH] Add assertion message Co-authored-by: Madeesh Kannan --- curated_transformers/layers/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/curated_transformers/layers/attention.py b/curated_transformers/layers/attention.py index b32561c7..fedec5b4 100644 --- a/curated_transformers/layers/attention.py +++ b/curated_transformers/layers/attention.py @@ -764,7 +764,7 @@ def forward( # # Doing this properly requires a redesign of our AttentionMask # class. - assert attention_mask.bool_mask.size(-2) == 1 + assert attention_mask.bool_mask.size(-2) == 1, "Torch SDP does not support attention masks with non-broadcastable query length yet" return torch.where( attention_mask.bool_mask.transpose(-1, -2), attn_values, 0.0 )