29
29
30
30
#include " ggml-vulkan-shaders.hpp"
31
31
32
- #define VK_API_VERSION VK_API_VERSION_1_2
33
-
34
32
// Integer ceiling division: rounds M / N up to the next whole number.
// Intended for non-negative M and positive N with integer types.
// Note: N is evaluated twice, so do not pass expressions with side effects.
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
35
33
36
34
// Vendor ID constant for AMD (0x1002).
// NOTE(review): presumably compared against the Vulkan physical device's vendorID — confirm at use sites.
#define VK_VENDOR_ID_AMD 0x1002
@@ -1614,11 +1612,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
1614
1612
CREATE_MM (PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
1615
1613
CREATE_MM (PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
1616
1614
1617
- CREATE_MM (pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 )
1618
- CREATE_MM (pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 )
1619
-
1620
1615
CREATE_MM2 (pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3 )
1621
- CREATE_MM2 (pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3 )
1622
1616
CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc , matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 )
1623
1617
CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc , matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 )
1624
1618
CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc , matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 )
@@ -1631,21 +1625,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
1631
1625
CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc , matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3 )
1632
1626
CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc , matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 )
1633
1627
1634
- CREATE_MM (pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4 )
1635
1628
CREATE_MM2 (pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4 )
1636
- CREATE_MM2 (pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4 )
1637
-
1638
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc , matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1639
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc , matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1640
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc , matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1641
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc , matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1642
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc , matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1643
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc , matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1644
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc , matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1645
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc , matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1646
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc , matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1647
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc , matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1648
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc , matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1629
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc , matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1630
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc , matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1631
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc , matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1632
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc , matmul_id_q5_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1633
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc , matmul_id_q8_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1634
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc , matmul_id_q2_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1635
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc , matmul_id_q3_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1636
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc , matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1637
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc , matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1638
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc , matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1639
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc , matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1649
1640
#undef CREATE_MM
1650
1641
#undef CREATE_MM2
1651
1642
} else
@@ -2287,6 +2278,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
2287
2278
}
2288
2279
#endif
2289
2280
2281
+ VkPhysicalDeviceMaintenance4Features maint4_features {};
2282
+ maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
2283
+ if (maintenance4_support) {
2284
+ last_struct->pNext = (VkBaseOutStructure *)&maint4_features;
2285
+ last_struct = (VkBaseOutStructure *)&maint4_features;
2286
+ device_extensions.push_back (" VK_KHR_maintenance4" );
2287
+ }
2288
+
2290
2289
vkGetPhysicalDeviceFeatures2 (device->physical_device , &device_features2);
2291
2290
2292
2291
device->fp16 = device->fp16 && vk12_features.shaderFloat16 ;
@@ -2662,7 +2661,14 @@ void ggml_vk_instance_init() {
2662
2661
2663
2662
vk_instance_initialized = true ;
2664
2663
2665
- vk::ApplicationInfo app_info{ " ggml-vulkan" , 1 , nullptr , 0 , VK_API_VERSION };
2664
+ uint32_t api_version = vk::enumerateInstanceVersion ();
2665
+
2666
+ if (api_version < VK_API_VERSION_1_2) {
2667
+ std::cerr << " ggml_vulkan: Error: Vulkan 1.2 required." << std::endl;
2668
+ GGML_ABORT (" fatal error" );
2669
+ }
2670
+
2671
+ vk::ApplicationInfo app_info{ " ggml-vulkan" , 1 , nullptr , 0 , api_version };
2666
2672
2667
2673
const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties ();
2668
2674
const bool validation_ext = ggml_vk_instance_validation_ext_available (instance_extensions);
@@ -2972,7 +2978,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
2972
2978
}
2973
2979
}
2974
2980
2975
- GGML_ASSERT (src1_type == GGML_TYPE_F32);
2981
+ GGML_ASSERT (src1_type == GGML_TYPE_F32 || (ctx-> device -> coopmat2 && src1_type == GGML_TYPE_F16) );
2976
2982
2977
2983
switch (src0_type) {
2978
2984
case GGML_TYPE_Q4_0:
@@ -3812,8 +3818,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
3812
3818
src1_uma = d_Qy != nullptr ;
3813
3819
}
3814
3820
3815
- const bool x_non_contig = !ggml_vk_dim01_contiguous (src0);
3816
- // Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf
3821
+ // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
3822
+ const bool x_non_contig = (ctx->device ->coopmat2 && src0->type == GGML_TYPE_F32) ||
3823
+ !ggml_vk_dim01_contiguous (src0);
3817
3824
const bool y_non_contig = (ctx->device ->coopmat2 && src1->type == GGML_TYPE_F32) ||
3818
3825
!ggml_vk_dim01_contiguous (src1);
3819
3826
@@ -4393,8 +4400,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
4393
4400
ids_uma = d_ids != nullptr ;
4394
4401
}
4395
4402
4396
- const bool x_non_contig = !ggml_vk_dim01_contiguous (src0);
4397
- const bool y_non_contig = !ggml_vk_dim01_contiguous (src1);
4403
+ // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
4404
+ const bool x_non_contig = (ctx->device ->coopmat2 && src0->type == GGML_TYPE_F32) ||
4405
+ !ggml_vk_dim01_contiguous (src0);
4406
+ const bool y_non_contig = (ctx->device ->coopmat2 && src1->type == GGML_TYPE_F32) ||
4407
+ !ggml_vk_dim01_contiguous (src1);
4398
4408
4399
4409
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
4400
4410
@@ -4404,7 +4414,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
4404
4414
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
4405
4415
4406
4416
if (qx_needs_dequant) {
4407
- GGML_ABORT (" fatal error" );
4417
+ // Fall back to dequant + f16 mulmat
4418
+ mmp = ggml_vk_get_mul_mat_mat_id_pipeline (ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params [0 ]);
4408
4419
}
4409
4420
4410
4421
// Not implemented
0 commit comments