@@ -633,14 +633,14 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
633
633
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true );
634
634
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true );
635
635
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true );
636
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true );
637
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true );
638
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true );
639
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true );
640
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true );
641
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true );
642
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true );
643
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true );
636
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx-> support_simdgroup_mm );
637
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx-> support_simdgroup_mm );
638
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx-> support_simdgroup_mm );
639
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx-> support_simdgroup_mm );
640
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx-> support_simdgroup_mm );
641
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx-> support_simdgroup_mm );
642
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx-> support_simdgroup_reduction );
643
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx-> support_simdgroup_reduction );
644
644
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true );
645
645
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true );
646
646
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true );
@@ -772,8 +772,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
772
772
case GGML_OP_TIMESTEP_EMBEDDING:
773
773
case GGML_OP_ARGSORT:
774
774
case GGML_OP_LEAKY_RELU:
775
- case GGML_OP_FLASH_ATTN_EXT:
776
775
return true ;
776
+ case GGML_OP_FLASH_ATTN_EXT:
777
+ return ctx->support_simdgroup_mm ; // TODO: over-restricted for vec-kernels
777
778
case GGML_OP_MUL_MAT:
778
779
case GGML_OP_MUL_MAT_ID:
779
780
return ctx->support_simdgroup_reduction &&
0 commit comments