Skip to content

Commit a3738b2

Browse files
0cc4m authored and ggerganov committed
vulkan : implement Stable Diffusion operators (ggml/904)
* Fix Vulkan repeat op
* Implement Vulkan concat op
* Delete old Vulkan shader generator
* Implement Vulkan im2col op
* Implement Vulkan unary gelu_quick op
* Implement Vulkan group_norm op
* Implement Vulkan timestep_embedding op
* Implement Vulkan upscale op
* Fix Vulkan vk_context tensor extra index issue
* Fix Vulkan matmul shader parameter bug
* Properly fix Vulkan matmul shader parameter bug
* Add Vulkan ADD f16 + f32 -> f16 operator support
* Implement Vulkan tanh op
* Fix Vulkan group count too large Validation error on non-Nvidia GPUs
* Throw error when too much memory is requested
* Fix another Vulkan group count too large Validation error on non-Nvidia GPUs
* Fix matmul MMQ condition
* Implement Vulkan pad op
* Fix Vulkan crash when tensor is used multiple times in a compute graph
* Add Vulkan CONCAT f16 + f16 -> f16 op
* Add Vulkan LEAKY_RELU op
1 parent 655858a commit a3738b2

28 files changed

+1032
-293
lines changed

ggml/src/ggml-vulkan.cpp

Lines changed: 605 additions & 235 deletions
Large diffs are not rendered by default.

ggml/src/vulkan-shaders/add.comp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
#include "generic_binary_head.comp"
55

66
void main() {
7-
if (gl_GlobalInvocationID.x >= p.ne) {
7+
const uint idx = get_idx();
8+
9+
if (idx >= p.ne) {
810
return;
911
}
1012

11-
data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) + FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
13+
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[src1_idx(idx)]));
1214
}

ggml/src/vulkan-shaders/clamp.comp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
#include "generic_unary_head.comp"
55

66
void main() {
7-
if (gl_GlobalInvocationID.x >= p.ne) {
7+
const uint idx = get_idx();
8+
9+
if (idx >= p.ne) {
810
return;
911
}
1012

11-
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
12-
data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
13+
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
14+
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
1315
}

ggml/src/vulkan-shaders/concat.comp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#version 450
2+
3+
#include "types.comp"
4+
#include "generic_binary_head.comp"
5+
6+
// Concatenation kernel: each invocation writes one destination element, taking
// it from data_a when the element's coordinates lie inside src0's extents
// (ne00..ne03) and from data_b otherwise.
void main() {
7+
// Flatten the 3D global invocation ID into one linear element index.
// 262144 == 512 * 512, so y advances in steps of 512 and z in steps of 512*512.
// NOTE(review): assumes the host dispatch keeps the global x extent at 512 and
// the global y extent at <= 512 -- confirm against the dispatch code.
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
8+
// Dimension (0..3) along which the inputs are joined; it selects which src0
// extent becomes the src1 offset below.
const int dim = p.param3;
9+
10+
// Invocations whose linear index is at or beyond p.ne have nothing to write.
if (idx >= p.ne) {
11+
return;
12+
}
13+
14+
// Decompose idx into 4D coordinates (i3, i2, i1, i0) using the destination
// extents ne20/ne21/ne22, with i0 varying fastest.
const uint i3 = idx / (p.ne22*p.ne21*p.ne20);
15+
const uint i3_offset = i3 * p.ne22*p.ne21*p.ne20;
16+
const uint i2 = (idx - i3_offset) / (p.ne21*p.ne20);
17+
const uint i2_offset = i2*p.ne21*p.ne20;
18+
const uint i1 = (idx - i3_offset - i2_offset) / p.ne20;
19+
const uint i0 = idx - i3_offset - i2_offset - i1*p.ne20;
20+
21+
// Per-dimension offset of src1 inside the destination: zero everywhere except
// along `dim`, where it equals src0's extent in that dimension.
uint o[4] = {0, 0, 0, 0};
22+
o[dim] = dim == 0 ? p.ne00 : (dim == 1 ? p.ne01 : (dim == 2 ? p.ne02 : p.ne03));
23+
24+
// Element indices into each buffer: coordinates weighted by the nb* values,
// which are used here directly as per-dimension strides in array elements.
const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00;
25+
// The unsigned subtraction wraps around for elements that fall inside src0's
// region; src1_idx is only used for indexing when is_src0 is false.
const uint src1_idx = (i3 - o[3])*p.nb13 + (i2 - o[2])*p.nb12 + (i1 - o[1])*p.nb11 + (i0 - o[0])*p.nb10;
26+
const uint dst_idx = i3*p.nb23 + i2*p.nb22 + i1*p.nb21 + i0*p.nb20;
27+
28+
// True when every coordinate is within src0's extents.
const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
29+
30+
// Both branches store the same selected value; the workaround variant omits
// the explicit D_TYPE(...) conversion.
#ifndef OPTIMIZATION_ERROR_WORKAROUND
31+
data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : data_b[src1_idx]);
32+
#else
33+
data_d[p.d_offset + dst_idx] = is_src0 ? data_a[src0_idx] : data_b[src1_idx];
34+
#endif
35+
}

ggml/src/vulkan-shaders/copy.comp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
#include "generic_unary_head.comp"
55

66
void main() {
7-
if (gl_GlobalInvocationID.x >= p.ne) {
7+
const uint idx = get_idx();
8+
9+
if (idx >= p.ne) {
810
return;
911
}
1012

1113
#ifndef OPTIMIZATION_ERROR_WORKAROUND
12-
data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
14+
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx(idx)]);
1315
#else
14-
data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = data_a[src0_idx(gl_GlobalInvocationID.x)];
16+
data_d[p.d_offset + dst_idx(idx)] = data_a[src0_idx(idx)];
1517
#endif
1618
}

ggml/src/vulkan-shaders/div.comp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
#include "generic_binary_head.comp"
55

66
void main() {
7-
if (gl_GlobalInvocationID.x >= p.ne) {
7+
const uint idx = get_idx();
8+
9+
if (idx >= p.ne) {
810
return;
911
}
1012

11-
data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) / FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
13+
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) / FLOAT_TYPE(data_b[src1_idx(idx)]));
1214
}

ggml/src/vulkan-shaders/gelu.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
1313
void main() {
1414
const float GELU_COEF_A = 0.044715f;
1515
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
16-
const uint i = gl_GlobalInvocationID.x;
16+
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
1717

1818
if (i >= p.KX) {
1919
return;

ggml/src/vulkan-shaders/gelu_quick.comp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#version 450
2+
3+
#include "generic_head.comp"
4+
#include "types.comp"
5+
6+
#extension GL_EXT_control_flow_attributes : enable
7+
8+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
9+
10+
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
11+
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
12+
13+
// Element-wise "quick GELU": each invocation reads one element x from data_a
// and writes x * 1/(1 + exp(-1.702 * x)) to the same position in data_d.
void main() {
14+
// Negative so that exp(GELU_QUICK_COEF * x) is the exp(-1.702 * x) term of the
// logistic function evaluated below.
const float GELU_QUICK_COEF = -1.702f;
15+
// Flatten the 3D global invocation ID into one linear element index.
// 512 matches local_size_x declared in this shader; 262144 == 512 * 512.
// NOTE(review): assumes the host dispatch keeps the global x extent at 512 and
// the global y extent at <= 512 -- confirm against the dispatch code.
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
16+
17+
// Invocations at or beyond p.KX have no element to process.
// NOTE(review): p.KX is declared in generic_head.comp (not shown here);
// presumably the element count -- confirm.
if (i >= p.KX) {
18+
return;
19+
}
20+
21+
// Compute in float regardless of A_TYPE, then convert the result to D_TYPE.
const float x = float(data_a[i]);
22+
data_d[i] = D_TYPE(x * (1.0f / (1.0f + exp(GELU_QUICK_COEF * x))));
23+
}

ggml/src/vulkan-shaders/generic_binary_head.comp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ layout (push_constant) uniform parameter
77
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
88
uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
99
uint d_offset;
10-
float param1; float param2;
10+
float param1; float param2; int param3;
1111
} p;
1212

1313
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
@@ -16,6 +16,10 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
1616
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
1717
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
1818

19+
// Returns this invocation's linear element index, flattened from the 3D
// global invocation ID: x + 512 * y + 262144 * z (262144 == 512 * 512).
// 512 matches local_size_x declared in this file.
// NOTE(review): assumes the host dispatch keeps the global x extent at 512 and
// the global y extent at <= 512, so distinct invocations get distinct
// indices -- confirm against the dispatch code.
uint get_idx() {
20+
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
21+
}
22+
1923
uint src0_idx(uint idx) {
2024
const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
2125
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;

ggml/src/vulkan-shaders/generic_unary_head.comp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
1414
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
1515
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
1616

17+
// Returns this invocation's linear element index, flattened from the 3D
// global invocation ID: x + 512 * y + 262144 * z (262144 == 512 * 512).
// 512 matches local_size_x declared in this file.
// NOTE(review): assumes the host dispatch keeps the global x extent at 512 and
// the global y extent at <= 512, so distinct invocations get distinct
// indices -- confirm against the dispatch code.
uint get_idx() {
18+
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
19+
}
20+
1721
uint src0_idx(uint idx) {
1822
const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
1923
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;

0 commit comments

Comments
 (0)