@@ -425,7 +425,7 @@ struct vk_device_struct {
425
425
vk_pipeline pipeline_div_norepeat[2][2][2];
426
426
427
427
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
428
- vk_pipeline pipeline_upscale_f32 ;
428
+ vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32 ;
429
429
vk_pipeline pipeline_scale_f32;
430
430
vk_pipeline pipeline_sqr_f32;
431
431
vk_pipeline pipeline_sin_f32;
@@ -895,6 +895,7 @@ struct vk_op_conv2d_dw_push_constants {
895
895
896
896
struct vk_op_upscale_push_constants {
897
897
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
898
+ uint32_t ne00; uint32_t ne01;
898
899
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
899
900
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
900
901
float sf0; float sf1; float sf2; float sf3;
@@ -2856,7 +2857,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
2856
2857
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2857
2858
ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2858
2859
2859
- ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1);
2860
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
2861
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
2862
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1);
2860
2863
2861
2864
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2862
2865
@@ -6536,8 +6539,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6536
6539
}
6537
6540
return nullptr;
6538
6541
case GGML_OP_UPSCALE:
6539
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) {
6540
- return ctx->device->pipeline_upscale_f32;
6542
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6543
+ int mode = ggml_get_op_params_i32(dst, 0);
6544
+ switch (mode) {
6545
+ case GGML_SCALE_MODE_NEAREST:
6546
+ return ctx->device->pipeline_upscale_nearest_f32;
6547
+ case GGML_SCALE_MODE_BILINEAR:
6548
+ return ctx->device->pipeline_upscale_bilinear_f32;
6549
+ case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS:
6550
+ return ctx->device->pipeline_upscale_bilinear_ac_f32;
6551
+ }
6541
6552
}
6542
6553
return nullptr;
6543
6554
case GGML_OP_SCALE:
@@ -7586,14 +7597,21 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co
7586
7597
7587
7598
static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7588
7599
const uint32_t src0_type_size = ggml_type_size(src0->type);
7600
+ const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);
7589
7601
7590
- const float sf0 = (float)dst->ne[0] / src0->ne[0];
7591
- const float sf1 = (float)dst->ne[1] / src0->ne[1];
7592
- const float sf2 = (float)dst->ne[2] / src0->ne[2];
7593
- const float sf3 = (float)dst->ne[3] / src0->ne[3];
7602
+ float sf0 = (float)dst->ne[0] / src0->ne[0];
7603
+ float sf1 = (float)dst->ne[1] / src0->ne[1];
7604
+ float sf2 = (float)dst->ne[2] / src0->ne[2];
7605
+ float sf3 = (float)dst->ne[3] / src0->ne[3];
7606
+
7607
+ if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7608
+ sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
7609
+ sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
7610
+ }
7594
7611
7595
7612
ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
7596
7613
(uint32_t)ggml_nelements(dst), 0, 0,
7614
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1],
7597
7615
(uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7598
7616
(uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
7599
7617
sf0, sf1, sf2, sf3,
@@ -10578,13 +10596,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10578
10596
case GGML_OP_CLAMP:
10579
10597
return op->src[0]->type == GGML_TYPE_F32;
10580
10598
case GGML_OP_UPSCALE:
10581
- return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
10582
10599
case GGML_OP_ACC:
10583
10600
case GGML_OP_CONCAT:
10584
10601
case GGML_OP_SCALE:
10585
10602
case GGML_OP_PAD:
10603
+ case GGML_OP_ROLL:
10586
10604
case GGML_OP_DIAG_MASK_INF:
10587
- return true;
10588
10605
case GGML_OP_SOFT_MAX:
10589
10606
case GGML_OP_SOFT_MAX_BACK:
10590
10607
case GGML_OP_ARGSORT:
0 commit comments