commit 9905026 (1 parent: a7091bf)
extension/llm/custom_ops/op_sdpa.cpp
@@ -59,8 +59,8 @@ bool validate_flash_attention_args(
   ET_CHECK_OR_RETURN_FALSE(
       !attn_mask.has_value() ||
-          attn_mask.value().scalar_type() == query.scalar_type(),
-      "Attention mask must be a 2D tensor");
+          attn_mask.value().scalar_type() == ScalarType::Float,
+      "Attention mask must be a Float tensor");
       is_contiguous_dim_order(query.dim_order().data(), query.dim()),