Skip to content

Commit 10a0351

Browse files
authored
vulkan: add RTE variants for glu/add/sub/mul/div (#14653)
1 parent 68e37a6 commit 10a0351

File tree

8 files changed

+90
-32
lines changed

8 files changed

+90
-32
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2835,10 +2835,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
28352835
return s;
28362836
};
28372837

2838+
bool rte = device->float_controls_rte_fp16;
28382839
#define CREATE_BINARY(name, namemod, spec) \
28392840
for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
28402841
ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
2841-
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
2842+
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
28422843
"main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
28432844

28442845
CREATE_BINARY(add, , {0})
@@ -2890,8 +2891,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
28902891
#undef CREATE_UNARY
28912892

28922893
#define CREATE_GLU(name) \
2893-
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
2894-
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
2894+
if (device->float_controls_rte_fp16) { \
2895+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
2896+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
2897+
} else { \
2898+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
2899+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
2900+
}
28952901

28962902
CREATE_GLU(geglu)
28972903
CREATE_GLU(reglu)

ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
#version 450
22

3-
#if RTE16
4-
#extension GL_EXT_spirv_intrinsics : enable
5-
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
6-
#endif // RTE16
7-
3+
#include "rte.comp"
84
#include "types.comp"
95

106
#if defined(SET_ROWS) && QUANT_K == 1

ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#extension GL_EXT_shader_16bit_storage : require
22
#extension GL_EXT_control_flow_attributes : require
33

4+
#include "rte.comp"
5+
46
layout (push_constant) uniform parameter
57
{
68
uint ne;

ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#extension GL_EXT_shader_16bit_storage : require
22

3+
#include "rte.comp"
4+
35
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
46

57
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};

ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
#version 450
22

33
#extension GL_EXT_shader_16bit_storage : require
4-
#extension GL_EXT_spirv_intrinsics: enable
54
#extension GL_EXT_control_flow_attributes : require
65

7-
#if RTE16
8-
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
9-
#endif
6+
#include "rte.comp"
107

118
layout (push_constant) uniform parameter
129
{

ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
#include "types.comp"
22

33
#extension GL_EXT_shader_16bit_storage : require
4-
#extension GL_EXT_spirv_intrinsics: enable
54

6-
#if RTE16
7-
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
8-
#endif
5+
#include "rte.comp"
96

107
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
118

ggml/src/ggml-vulkan/vulkan-shaders/rte.comp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
2+
#if RTE16
3+
#extension GL_EXT_spirv_intrinsics : enable
4+
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
5+
#endif // RTE16

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -537,8 +537,10 @@ void process_shaders() {
537537
for (auto src0_f16 : {false, true}) {
538538
for (auto src1_f16 : {false, true}) {
539539
for (auto dst_f16 : {false, true}) {
540-
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16);
541-
string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}});
540+
for (auto rte : {false, true}) {
541+
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
542+
string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
543+
}
542544
}
543545
}
544546
}
@@ -592,16 +594,19 @@ void process_shaders() {
592594
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
593595
string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
594596

595-
string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
596-
string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
597-
string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
598-
string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
599-
string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
600-
string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
601-
string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
602-
string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
603-
string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
604-
string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
597+
for (auto rte : {false, true}) {
598+
std::string suffix = rte ? "_rte" : "";
599+
string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
600+
string_to_spv("geglu_f32" + suffix, "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
601+
string_to_spv("reglu_f16" + suffix, "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
602+
string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
603+
string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
604+
string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
605+
string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
606+
string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
607+
string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
608+
string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
609+
}
605610

606611
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
607612
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
@@ -709,11 +714,59 @@ void write_output_files() {
709714
std::remove(path.c_str());
710715
}
711716
}
717+
718+
std::string suffixes[2] = {"_f32", "_f16"};
712719
for (const char *op : {"add", "sub", "mul", "div"}) {
713-
fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op);
714-
fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op);
715-
fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op);
716-
fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op);
720+
fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op);
721+
fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op);
722+
std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = ";
723+
std::string len = "uint64_t " + std::string(op) + "_len[2][2][2][2] = ";
724+
for (uint32_t t0 = 0; t0 < 2; ++t0) {
725+
if (t0 == 0) {
726+
data += "{";
727+
len += "{";
728+
}
729+
for (uint32_t t1 = 0; t1 < 2; ++t1) {
730+
if (t1 == 0) {
731+
data += "{";
732+
len += "{";
733+
}
734+
for (uint32_t t2 = 0; t2 < 2; ++t2) {
735+
if (t2 == 0) {
736+
data += "{";
737+
len += "{";
738+
}
739+
for (uint32_t rte = 0; rte < 2; ++rte) {
740+
if (rte == 0) {
741+
data += "{";
742+
len += "{";
743+
}
744+
data += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : "");
745+
len += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : "");
746+
data += "_data,";
747+
len += "_len,";
748+
if (rte == 1) {
749+
data += "}, ";
750+
len += "}, ";
751+
}
752+
}
753+
if (t2 == 1) {
754+
data += "}, ";
755+
len += "}, ";
756+
}
757+
}
758+
if (t1 == 1) {
759+
data += "}, ";
760+
len += "}, ";
761+
}
762+
}
763+
if (t0 == 1) {
764+
data += "};\n";
765+
len += "};\n";
766+
}
767+
}
768+
fprintf(src, data.c_str());
769+
fprintf(src, len.c_str());
717770
}
718771
fclose(hdr);
719772
fclose(src);

0 commit comments

Comments
 (0)