
Commit 9905026

[Executorch][llm] Make mask tensor float only for sdpa (#12142)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #12131 by @kimishpatel ^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/kimishpatel/195/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/kimishpatel/195/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/kimishpatel/194/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/kimishpatel/195/orig
@diff-train-skip-merge

Co-authored-by: Kimish Patel <kimishpatel@fb.com>
1 parent a7091bf commit 9905026

File tree

1 file changed: +2 −2 lines changed


extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 2 additions & 2 deletions

@@ -59,8 +59,8 @@ bool validate_flash_attention_args(
 
   ET_CHECK_OR_RETURN_FALSE(
       !attn_mask.has_value() ||
-          attn_mask.value().scalar_type() == query.scalar_type(),
-      "Attention mask must be a 2D tensor");
+          attn_mask.value().scalar_type() == ScalarType::Float,
+      "Attention mask must be a Float tensor");
 
   ET_CHECK_OR_RETURN_FALSE(
       is_contiguous_dim_order(query.dim_order().data(), query.dim()),
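
The functional change is the dtype check above: the custom SDPA op now rejects any attention mask whose scalar type is not Float, instead of only requiring the mask dtype to match the query dtype (the old error string was also a leftover from the 2D-shape check and is corrected here). Below is a minimal, hypothetical C++ sketch of what this implies for callers, assuming the common additive-mask convention (0.0f where attention is allowed, -inf where it is masked); to_float_additive_mask is an illustrative helper and not part of the ExecuTorch API.

// Hypothetical helper (not ExecuTorch API): build the Float mask the op now
// requires from a boolean "keep" vector, using the additive-mask convention.
#include <cstddef>
#include <limits>
#include <vector>

std::vector<float> to_float_additive_mask(const std::vector<bool>& keep) {
  std::vector<float> mask(keep.size());
  for (size_t i = 0; i < keep.size(); ++i) {
    // 0.0f where attention is allowed, -inf where the position is masked out.
    mask[i] = keep[i] ? 0.0f : -std::numeric_limits<float>::infinity();
  }
  return mask;
}

A mask built this way (and materialized as a Float tensor) satisfies the new ScalarType::Float check regardless of the query's own dtype.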
