Skip to content

Commit 7631056

Browse files
committed
Merge pull request opencv#19114 from alalek:issue_18937
2 parents 4107dc7 + 4b3d2c8 commit 7631056

File tree

4 files changed

+117
-51
lines changed

4 files changed

+117
-51
lines changed

modules/core/src/ocl.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2949,6 +2949,15 @@ bool Kernel::empty() const
29492949
return ptr() == 0;
29502950
}
29512951

2952+
static cv::String dumpValue(size_t sz, const void* p)
2953+
{
2954+
if (sz == 4)
2955+
return cv::format("%d / %uu / 0x%08x / %g", *(int*)p, *(int*)p, *(int*)p, *(float*)p);
2956+
if (sz == 8)
2957+
return cv::format("%lld / %lluu / 0x%16llx / %g", *(long long*)p, *(long long*)p, *(long long*)p, *(double*)p);
2958+
return cv::format("%p", p);
2959+
}
2960+
29522961
int Kernel::set(int i, const void* value, size_t sz)
29532962
{
29542963
if (!p || !p->handle)
@@ -2959,7 +2968,7 @@ int Kernel::set(int i, const void* value, size_t sz)
29592968
p->cleanupUMats();
29602969

29612970
cl_int retval = clSetKernelArg(p->handle, (cl_uint)i, sz, value);
2962-
CV_OCL_DBG_CHECK_RESULT(retval, cv::format("clSetKernelArg('%s', arg_index=%d, size=%d, value=%p)", p->name.c_str(), (int)i, (int)sz, (void*)value).c_str());
2971+
CV_OCL_DBG_CHECK_RESULT(retval, cv::format("clSetKernelArg('%s', arg_index=%d, size=%d, value=%s)", p->name.c_str(), (int)i, (int)sz, dumpValue(sz, value).c_str()).c_str());
29632972
if (retval != CL_SUCCESS)
29642973
return -1;
29652974
return i+1;

modules/dnn/src/ocl4dnn/src/math_functions.cpp

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,13 @@ ocl::Image2D ocl4dnnGEMMCopyBufferToImage(UMat buffer, int offset,
8888
size_t global_copy[2];
8989
global_copy[0] = width;
9090
global_copy[1] = height;
91-
oclk_gemm_copy.set(0, ocl::KernelArg::PtrReadOnly(buffer));
92-
oclk_gemm_copy.set(1, image);
93-
oclk_gemm_copy.set(2, offset);
94-
oclk_gemm_copy.set(3, width);
95-
oclk_gemm_copy.set(4, height);
96-
oclk_gemm_copy.set(5, ld);
97-
oclk_gemm_copy.run(2, global_copy, NULL, false);
91+
oclk_gemm_copy
92+
.args(
93+
ocl::KernelArg::PtrReadOnly(buffer),
94+
image, offset,
95+
width, height,
96+
ld)
97+
.run(2, global_copy, NULL, false);
9898
}
9999
} else {
100100
if (!padding)
@@ -112,13 +112,13 @@ ocl::Image2D ocl4dnnGEMMCopyBufferToImage(UMat buffer, int offset,
112112
global_copy[0] = padded_width;
113113
global_copy[1] = padded_height;
114114

115-
oclk_gemm_copy.set(0, ocl::KernelArg::PtrReadOnly(buffer));
116-
oclk_gemm_copy.set(1, image);
117-
oclk_gemm_copy.set(2, offset);
118-
oclk_gemm_copy.set(3, width);
119-
oclk_gemm_copy.set(4, height);
120-
oclk_gemm_copy.set(5, ld);
121-
115+
oclk_gemm_copy
116+
.args(
117+
ocl::KernelArg::PtrReadOnly(buffer),
118+
image, offset,
119+
width, height,
120+
ld)
121+
.run(2, global_copy, NULL, false);
122122
oclk_gemm_copy.run(2, global_copy, NULL, false);
123123
}
124124
}
@@ -465,8 +465,12 @@ static bool ocl4dnnFastBufferGEMM(const CBLAS_TRANSPOSE TransA,
465465
kernel_name += "_float";
466466
}
467467

468+
bool isBetaZero = beta == 0;
469+
468470
String opts = format("-DTYPE=%d", halfPrecisionMode ? TYPE_HALF : TYPE_FLOAT);
469-
ocl::Kernel oclk_gemm_float(kernel_name.c_str(), ocl::dnn::gemm_buffer_oclsrc, opts);
471+
if (isBetaZero)
472+
opts += " -DZERO_BETA=1";
473+
470474
size_t local[2] = {};
471475
size_t global[2] = {};
472476
if (TransA == CblasNoTrans && TransB != CblasNoTrans && is_small_batch) {
@@ -496,27 +500,37 @@ static bool ocl4dnnFastBufferGEMM(const CBLAS_TRANSPOSE TransA,
496500
local[1] = ly;
497501
}
498502

499-
int arg_idx = 0;
500-
oclk_gemm_float.set(arg_idx++, ocl::KernelArg::PtrReadOnly(A));
501-
oclk_gemm_float.set(arg_idx++, offA);
502-
oclk_gemm_float.set(arg_idx++, ocl::KernelArg::PtrReadOnly(B));
503-
oclk_gemm_float.set(arg_idx++, offB);
504-
oclk_gemm_float.set(arg_idx++, ocl::KernelArg::PtrWriteOnly(C));
505-
oclk_gemm_float.set(arg_idx++, offC);
506-
oclk_gemm_float.set(arg_idx++, M);
507-
oclk_gemm_float.set(arg_idx++, N);
508-
oclk_gemm_float.set(arg_idx++, K);
509-
oclk_gemm_float.set(arg_idx++, (float)alpha);
510-
oclk_gemm_float.set(arg_idx++, (float)beta);
511-
512503
bool ret = true;
513-
if (TransB == CblasNoTrans || TransA != CblasNoTrans) {
504+
if (TransB == CblasNoTrans || TransA != CblasNoTrans)
505+
{
506+
// _NN_
514507
int stride = 256;
515508
for (int start_index = 0; start_index < K; start_index += stride) {
516-
oclk_gemm_float.set(arg_idx, start_index);
517-
ret = oclk_gemm_float.run(2, global, local, false);
509+
ocl::Kernel oclk_gemm_float(kernel_name.c_str(), ocl::dnn::gemm_buffer_oclsrc, opts);
510+
oclk_gemm_float.args(
511+
ocl::KernelArg::PtrReadOnly(A), offA,
512+
ocl::KernelArg::PtrReadOnly(B), offB,
513+
isBetaZero ? ocl::KernelArg::PtrWriteOnly(C) : ocl::KernelArg::PtrReadWrite(C), offC,
514+
M, N, K,
515+
(float)alpha, (float)beta,
516+
start_index
517+
);
518+
ret &= oclk_gemm_float.run(2, global, local, false);
518519
}
519-
} else {
520+
}
521+
else
522+
{
523+
// _NT_
524+
//C.reshape(1,1).setTo(0xfe00 /*FP16 NAN*/); // stable one-line reproducer for https://github.com/opencv/opencv/issues/18937
525+
//C.reshape(1,1).setTo(0); // non-optimal fixup (and not accurate)
526+
ocl::Kernel oclk_gemm_float(kernel_name.c_str(), ocl::dnn::gemm_buffer_oclsrc, opts);
527+
oclk_gemm_float.args(
528+
ocl::KernelArg::PtrReadOnly(A), offA,
529+
ocl::KernelArg::PtrReadOnly(B), offB,
530+
isBetaZero ? ocl::KernelArg::PtrWriteOnly(C) : ocl::KernelArg::PtrReadWrite(C), offC,
531+
M, N, K,
532+
(float)alpha, (float)beta
533+
);
520534
ret = oclk_gemm_float.run(2, global, local, false);
521535
}
522536
return ret;

modules/dnn/src/opencl/gemm_buffer.cl

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,12 @@
9090
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
9191
#endif
9292

93+
#ifdef ZERO_BETA
94+
#define BETA_ZERO_CHECK(b0, v) (b0)
95+
#else
96+
#define BETA_ZERO_CHECK(b0, v) (v)
97+
#endif
98+
9399
#define VEC_SIZE 4
94100
#define LWG_HEIGHT 4
95101
#define TILE_M 8
@@ -143,14 +149,14 @@ __kernel void TEMPLATE(gemm_buffer_NN, Dtype)(
143149
int row6 = mad24(global_y, TILE_M, 6) < M ? 6 : border;
144150
int row7 = mad24(global_y, TILE_M, 7) < M ? 7 : border;
145151

146-
Dtype4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0);
147-
Dtype4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + 1 * N) : beta * vload4(0, dst_write0 + 1 * N);
148-
Dtype4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N);
149-
Dtype4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N);
150-
Dtype4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N);
151-
Dtype4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N);
152-
Dtype4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N);
153-
Dtype4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N);
152+
Dtype4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0));
153+
Dtype4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + 1 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 1 * N));
154+
Dtype4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 2 * N));
155+
Dtype4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 3 * N));
156+
Dtype4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 4 * N));
157+
Dtype4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 5 * N));
158+
Dtype4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 6 * N));
159+
Dtype4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 7 * N));
154160

155161
int end_index = min(start_index + 256, K);
156162
int w = start_index;
@@ -579,7 +585,7 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)(
579585
output = (local_x == 5) ? _dot.s5 : output; \
580586
output = (local_x == 6) ? _dot.s6 : output; \
581587
output = (local_x == 7) ? _dot.s7 : output; \
582-
dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); \
588+
dst_write0[0] = BETA_ZERO_CHECK(alpha * output, mad(output, alpha, beta * dst_write0[0])); \
583589
dst_write0 += N;
584590

585591
if(global_x < N && global_y * 8 < M) {
@@ -765,7 +771,7 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)(
765771
output = (local_x == 5) ? _dot.s5 : output; \
766772
output = (local_x == 6) ? _dot.s6 : output; \
767773
output = (local_x == 7) ? _dot.s7 : output; \
768-
dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); \
774+
dst_write0[0] = BETA_ZERO_CHECK(alpha * output, mad(output, alpha, beta * dst_write0[0])); \
769775
dst_write0 += N;
770776

771777
if(global_x < N && global_y * 8 < M) {
@@ -819,8 +825,9 @@ void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)(
819825
const Dtype4 b1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]};
820826
#pragma unroll
821827
for(int j = 0; j < rows; ++j) {
822-
dot0[j] += b0 * vload4(i, srcb_read + j * K);
823-
dot1[j] += b1 * vload4(i, srcb_read + j * K);
828+
Dtype4 a = vload4(i, srcb_read + j * K);
829+
dot0[j] += b0 * a;
830+
dot1[j] += b1 * a;
824831
}
825832

826833
i += get_local_size(0);
@@ -859,11 +866,19 @@ void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)(
859866
}
860867
}
861868

869+
barrier(CLK_LOCAL_MEM_FENCE);
862870
if(lid == 0) {
863871
#pragma unroll
864872
for(int j = 0; j < rows; ++j) {
865-
dstc0[(x_gid * 4 + j)] = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)];
866-
dstc1[(x_gid * 4 + j)] = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)];
873+
#ifdef ZERO_BETA
874+
Dtype a0 = alpha * work_each0[j];
875+
Dtype a1 = alpha * work_each1[j];
876+
#else
877+
Dtype a0 = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)];
878+
Dtype a1 = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)];
879+
#endif
880+
dstc0[(x_gid * 4 + j)] = a0;
881+
dstc1[(x_gid * 4 + j)] = a1;
867882
}
868883
}
869884
}
@@ -952,9 +967,15 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_2,Dtype)(
952967
}
953968
}
954969

955-
if(lid == 0) {
970+
if(lid == 0)
971+
{
972+
#ifdef ZERO_BETA
973+
dstc0[x_gid] = alpha * work0[0];
974+
dstc1[x_gid] = alpha * work1[0];
975+
#else
956976
dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];
957977
dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];
978+
#endif
958979
}
959980
}
960981
}
@@ -1058,10 +1079,17 @@ void TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype)(
10581079
if(lid == 0) {
10591080
#pragma unroll
10601081
for(int j = 0; j < rows; ++j) {
1082+
#ifdef ZERO_BETA
1083+
dstc0[(x_gid * 4 + j)] = alpha * work_each0[j];
1084+
dstc1[(x_gid * 4 + j)] = alpha * work_each1[j];
1085+
dstc2[(x_gid * 4 + j)] = alpha * work_each2[j];
1086+
dstc3[(x_gid * 4 + j)] = alpha * work_each3[j];
1087+
#else
10611088
dstc0[(x_gid * 4 + j)] = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)];
10621089
dstc1[(x_gid * 4 + j)] = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)];
10631090
dstc2[(x_gid * 4 + j)] = alpha * work_each2[j] + beta * dstc2[(x_gid * 4 + j)];
10641091
dstc3[(x_gid * 4 + j)] = alpha * work_each3[j] + beta * dstc3[(x_gid * 4 + j)];
1092+
#endif
10651093
}
10661094
}
10671095
}
@@ -1179,10 +1207,17 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_4,Dtype)(
11791207
}
11801208

11811209
if(lid == 0) {
1210+
#ifdef ZERO_BETA
1211+
dstc0[x_gid] = alpha * work0[0];
1212+
dstc1[x_gid] = alpha * work1[0];
1213+
dstc2[x_gid] = alpha * work2[0];
1214+
dstc3[x_gid] = alpha * work3[0];
1215+
#else
11821216
dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];
11831217
dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];
11841218
dstc2[x_gid] = alpha * work2[0] + beta * dstc2[x_gid];
11851219
dstc3[x_gid] = alpha * work3[0] + beta * dstc3[x_gid];
1220+
#endif
11861221
}
11871222
}
11881223
}
@@ -1320,6 +1355,16 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)(
13201355
}
13211356

13221357
if(lid == 0) {
1358+
#ifdef ZERO_BETA
1359+
dstc0[x_gid] = alpha * work0[0];
1360+
dstc1[x_gid] = alpha * work1[0];
1361+
dstc2[x_gid] = alpha * work2[0];
1362+
dstc3[x_gid] = alpha * work3[0];
1363+
dstc4[x_gid] = alpha * work4[0];
1364+
dstc5[x_gid] = alpha * work5[0];
1365+
dstc6[x_gid] = alpha * work6[0];
1366+
dstc7[x_gid] = alpha * work7[0];
1367+
#else
13231368
dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];
13241369
dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];
13251370
dstc2[x_gid] = alpha * work2[0] + beta * dstc2[x_gid];
@@ -1328,6 +1373,7 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)(
13281373
dstc5[x_gid] = alpha * work5[0] + beta * dstc5[x_gid];
13291374
dstc6[x_gid] = alpha * work6[0] + beta * dstc6[x_gid];
13301375
dstc7[x_gid] = alpha * work7[0] + beta * dstc7[x_gid];
1376+
#endif
13311377
}
13321378
}
13331379
#undef SLM_SIZE

modules/dnn/test/test_onnx_importer.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -718,9 +718,6 @@ TEST_P(Test_ONNX_layers, Conv1d_variable_weight_bias)
718718

719719
TEST_P(Test_ONNX_layers, GatherMultiOutput)
720720
{
721-
if (cvtest::skipUnstableTests && backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16)
722-
throw SkipTestException("Skip unstable test: https://github.com/opencv/opencv/issues/18937");
723-
724721
#if defined(INF_ENGINE_RELEASE)
725722
if (target == DNN_TARGET_MYRIAD)
726723
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE);

0 commit comments

Comments
 (0)