@@ -4453,7 +4453,8 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg
4453
4453
4454
4454
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context ;
4455
4455
4456
- const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32 (dst, 0 );
4456
+ const int mode_flags = (ggml_scale_mode) ggml_get_op_params_i32 (dst, 0 );
4457
+ const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF );
4457
4458
cl_kernel kernel = nullptr ;
4458
4459
4459
4460
if (mode == GGML_SCALE_MODE_NEAREST) {
@@ -4484,18 +4485,22 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg
4484
4485
const cl_ulong nb02 = src0->nb [2 ];
4485
4486
const cl_ulong nb03 = src0->nb [3 ];
4486
4487
4487
- const int ne00_src = src0->ne [0 ];
4488
- const int ne01_src = src0->ne [1 ];
4488
+ const int ne00 = src0->ne [0 ];
4489
+ const int ne01 = src0->ne [1 ];
4490
+ const int ne02 = src0->ne [2 ];
4491
+ const int ne03 = src0->ne [3 ];
4489
4492
4490
- const int ne10_dst = dst->ne [0 ];
4491
- const int ne11_dst = dst->ne [1 ];
4492
- const int ne12_dst = dst->ne [2 ];
4493
- const int ne13_dst = dst->ne [3 ];
4493
+ const int ne0 = dst->ne [0 ];
4494
+ const int ne1 = dst->ne [1 ];
4495
+ const int ne2 = dst->ne [2 ];
4496
+ const int ne3 = dst->ne [3 ];
4497
+
4498
+ float sf0 = (float )ne0 / ne00;
4499
+ float sf1 = (float )ne1 / ne01;
4500
+ float sf2 = (float )ne2 / ne02;
4501
+ float sf3 = (float )ne3 / ne03;
4494
4502
4495
- const float sf0 = (float )dst->ne [0 ] / src0->ne [0 ];
4496
- const float sf1 = (float )dst->ne [1 ] / src0->ne [1 ];
4497
- const float sf2 = (float )dst->ne [2 ] / src0->ne [2 ];
4498
- const float sf3 = (float )dst->ne [3 ] / src0->ne [3 ];
4503
+ float pixel_offset = 0 .5f ;
4499
4504
4500
4505
CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra_src0->data_device ));
4501
4506
CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &off_src0));
@@ -4507,29 +4512,36 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg
4507
4512
CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (cl_ulong), &nb03));
4508
4513
4509
4514
if (mode == GGML_SCALE_MODE_NEAREST) {
4510
- CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne10_dst ));
4511
- CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &ne11_dst ));
4512
- CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (int ), &ne12_dst ));
4513
- CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (int ), &ne13_dst ));
4515
+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne0 ));
4516
+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &ne1 ));
4517
+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (int ), &ne2 ));
4518
+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (int ), &ne3 ));
4514
4519
CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (float ), &sf0));
4515
4520
CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (float ), &sf1));
4516
4521
CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (float ), &sf2));
4517
4522
CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (float ), &sf3));
4518
4523
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
4519
- CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne00_src));
4520
- CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &ne01_src));
4521
- CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (int ), &ne10_dst));
4522
- CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (int ), &ne11_dst));
4523
- CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (int ), &ne12_dst));
4524
- CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (int ), &ne13_dst));
4524
+ if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
4525
+ sf0 = (float )(ne0 - 1 ) / (ne00 - 1 );
4526
+ sf1 = (float )(ne1 - 1 ) / (ne01 - 1 );
4527
+ pixel_offset = 0 .0f ;
4528
+ }
4529
+
4530
+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &ne00));
4531
+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &ne01));
4532
+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (int ), &ne0));
4533
+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (int ), &ne1));
4534
+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (int ), &ne2));
4535
+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (int ), &ne3));
4525
4536
CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (float ), &sf0));
4526
4537
CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (float ), &sf1));
4527
4538
CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (float ), &sf2));
4528
4539
CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (float ), &sf3));
4540
+ CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (float ), &pixel_offset));
4529
4541
}
4530
4542
4531
4543
4532
- size_t dst_total_elements = (size_t )ne10_dst * ne11_dst * ne12_dst * ne13_dst ;
4544
+ size_t dst_total_elements = (size_t )ne0 * ne1 * ne2 * ne3 ;
4533
4545
if (dst_total_elements == 0 ) {
4534
4546
return ;
4535
4547
}
0 commit comments