diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 1788f7a46..deac4d8c9 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -163,20 +163,33 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { bool mmq_supported; switch (type) { + case GGML_TYPE_Q2_K: mmq_supported = ne11 < 384; break; + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + mmq_supported = ne11 < 1536; + break; + case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ6_K: + case GGML_TYPE_IQ2_K_R4: + case GGML_TYPE_IQ3_K_R4: + case GGML_TYPE_IQ4_K_R4: + case GGML_TYPE_IQ5_K_R4: + mmq_supported = ne11 < 1024; + break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ1_S: @@ -188,22 +201,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ2_KS: - case GGML_TYPE_IQ2_K: - case GGML_TYPE_IQ3_K: - case GGML_TYPE_IQ4_K: - case GGML_TYPE_IQ5_K: - case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: mmq_supported = true; break; - case GGML_TYPE_IQ2_K_R4: - case GGML_TYPE_IQ3_K_R4: - case GGML_TYPE_IQ4_K_R4: - case GGML_TYPE_IQ5_K_R4: - mmq_supported = ne11 < 1024; - break; default: mmq_supported = false; break;