Skip to content

Commit 74bb294

Browse files
Acly, ggerganov
authored and committed
vulkan : implement bilinear interpolation (ggml/1291)
ggml-ci
1 parent 3e303b1 commit 74bb294

File tree

2 files changed

+96
-15
lines changed

2 files changed

+96
-15
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 27 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -425,7 +425,7 @@ struct vk_device_struct {
425425
vk_pipeline pipeline_div_norepeat[2][2][2];
426426

427427
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
428-
vk_pipeline pipeline_upscale_f32;
428+
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
429429
vk_pipeline pipeline_scale_f32;
430430
vk_pipeline pipeline_sqr_f32;
431431
vk_pipeline pipeline_sin_f32;
@@ -895,6 +895,7 @@ struct vk_op_conv2d_dw_push_constants {
895895

896896
struct vk_op_upscale_push_constants {
897897
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
898+
uint32_t ne00; uint32_t ne01;
898899
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
899900
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
900901
float sf0; float sf1; float sf2; float sf3;
@@ -2856,7 +2857,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
28562857
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
28572858
ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
28582859

2859-
ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1);
2860+
ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
2861+
ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
2862+
ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1);
28602863

28612864
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
28622865

@@ -6536,8 +6539,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
65366539
}
65376540
return nullptr;
65386541
case GGML_OP_UPSCALE:
6539-
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) {
6540-
return ctx->device->pipeline_upscale_f32;
6542+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6543+
int mode = ggml_get_op_params_i32(dst, 0);
6544+
switch (mode) {
6545+
case GGML_SCALE_MODE_NEAREST:
6546+
return ctx->device->pipeline_upscale_nearest_f32;
6547+
case GGML_SCALE_MODE_BILINEAR:
6548+
return ctx->device->pipeline_upscale_bilinear_f32;
6549+
case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS:
6550+
return ctx->device->pipeline_upscale_bilinear_ac_f32;
6551+
}
65416552
}
65426553
return nullptr;
65436554
case GGML_OP_SCALE:
@@ -7586,14 +7597,21 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co
75867597

75877598
static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
75887599
const uint32_t src0_type_size = ggml_type_size(src0->type);
7600+
const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);
75897601

7590-
const float sf0 = (float)dst->ne[0] / src0->ne[0];
7591-
const float sf1 = (float)dst->ne[1] / src0->ne[1];
7592-
const float sf2 = (float)dst->ne[2] / src0->ne[2];
7593-
const float sf3 = (float)dst->ne[3] / src0->ne[3];
7602+
float sf0 = (float)dst->ne[0] / src0->ne[0];
7603+
float sf1 = (float)dst->ne[1] / src0->ne[1];
7604+
float sf2 = (float)dst->ne[2] / src0->ne[2];
7605+
float sf3 = (float)dst->ne[3] / src0->ne[3];
7606+
7607+
if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7608+
sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
7609+
sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
7610+
}
75947611

75957612
ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
75967613
(uint32_t)ggml_nelements(dst), 0, 0,
7614+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1],
75977615
(uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
75987616
(uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
75997617
sf0, sf1, sf2, sf3,
@@ -10578,13 +10596,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1057810596
case GGML_OP_CLAMP:
1057910597
return op->src[0]->type == GGML_TYPE_F32;
1058010598
case GGML_OP_UPSCALE:
10581-
return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
1058210599
case GGML_OP_ACC:
1058310600
case GGML_OP_CONCAT:
1058410601
case GGML_OP_SCALE:
1058510602
case GGML_OP_PAD:
10603+
case GGML_OP_ROLL:
1058610604
case GGML_OP_DIAG_MASK_INF:
10587-
return true;
1058810605
case GGML_OP_SOFT_MAX:
1058910606
case GGML_OP_SOFT_MAX_BACK:
1059010607
case GGML_OP_ARGSORT:

ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp

Lines changed: 69 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
layout (push_constant) uniform parameter
44
{
55
uint ne; uint a_offset; uint d_offset;
6+
uint ne00; uint ne01;
67
uint nb00; uint nb01; uint nb02; uint nb03;
78
uint ne10; uint ne11; uint ne12; uint ne13;
89
float sf0; float sf1; float sf2; float sf3;
@@ -15,6 +16,61 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
1516
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
1617
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
1718

19+
// from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag
20+
#define NEAREST 0
21+
#define BILINEAR 1
22+
#define ALIGN_CORNERS (1 << 8)
23+
24+
layout (constant_id = 0) const uint scale_mode = 0;
25+
26+
float fetch_nearest(uint i10, uint i11, uint i12, uint i13) {
27+
const uint i00 = uint(i10 / p.sf0);
28+
const uint i01 = uint(i11 / p.sf1);
29+
const uint i02 = uint(i12 / p.sf2);
30+
const uint i03 = uint(i13 / p.sf3);
31+
32+
return data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00];
33+
}
34+
35+
float fetch_bilinear(ivec2 c0, ivec2 c1, vec2 d, uint i12, uint i13) {
36+
const uint i02 = uint(i12 / p.sf2);
37+
const uint i03 = uint(i13 / p.sf3);
38+
const uint base = p.a_offset + i03 * p.nb03 + i02 * p.nb02;
39+
40+
const float v00 = data_a[base + c0.y * p.nb01 + c0.x * p.nb00];
41+
const float v01 = data_a[base + c0.y * p.nb01 + c1.x * p.nb00];
42+
const float v10 = data_a[base + c1.y * p.nb01 + c0.x * p.nb00];
43+
const float v11 = data_a[base + c1.y * p.nb01 + c1.x * p.nb00];
44+
45+
return
46+
v00 * (1.0-d.x) * (1.0-d.y) +
47+
v01 * d.x * (1.0-d.y) +
48+
v10 * (1.0-d.x) * d.y +
49+
v11 * d.x * d.y;
50+
}
51+
52+
float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {
53+
const ivec2 ne0 = ivec2(p.ne00, p.ne01);
54+
55+
const vec2 c = (vec2(i10, i11) + 0.5) / vec2(p.sf0, p.sf1) - 0.5;
56+
const vec2 c0f = floor(c);
57+
const vec2 d = c - c0f;
58+
const ivec2 c0 = max(ivec2(c0f), 0);
59+
const ivec2 c1 = min(ivec2(c0f + 1), ne0 - 1);
60+
61+
return fetch_bilinear(c0, c1, d, i12, i13);
62+
}
63+
64+
float interpolate_bilinear_align_corners(uint i10, uint i11, uint i12, uint i13) {
65+
const vec2 c = vec2(i10, i11) / vec2(p.sf0, p.sf1);
66+
const vec2 c0f = floor(c);
67+
const vec2 d = c - c0f;
68+
const ivec2 c0 = ivec2(c0f);
69+
const ivec2 c1 = c0 + 1;
70+
71+
return fetch_bilinear(c0, c1, d, i12, i13);
72+
}
73+
1874
void main() {
1975
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
2076

@@ -27,10 +83,18 @@ void main() {
2783
const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12;
2884
const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13;
2985

30-
const uint i00 = uint(i10 / p.sf0);
31-
const uint i01 = uint(i11 / p.sf1);
32-
const uint i02 = uint(i12 / p.sf2);
33-
const uint i03 = uint(i13 / p.sf3);
86+
float result;
87+
switch (scale_mode) {
88+
case NEAREST:
89+
result = fetch_nearest(i10, i11, i12, i13);
90+
break;
91+
case BILINEAR:
92+
result = interpolate_bilinear(i10, i11, i12, i13);
93+
break;
94+
case BILINEAR | ALIGN_CORNERS:
95+
result = interpolate_bilinear_align_corners(i10, i11, i12, i13);
96+
break;
97+
}
3498

35-
data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
99+
data_d[p.d_offset + idx] = D_TYPE(result);
36100
}

0 commit comments

Comments
 (0)