90
90
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
91
91
#endif
92
92
93
+ #ifdef ZERO_BETA
94
+ #define BETA_ZERO_CHECK (b0 , v ) (b0)
95
+ #else
96
+ #define BETA_ZERO_CHECK (b0 , v ) (v)
97
+ #endif
98
+
93
99
#define VEC_SIZE 4
94
100
#define LWG_HEIGHT 4
95
101
#define TILE_M 8
@@ -143,14 +149,14 @@ __kernel void TEMPLATE(gemm_buffer_NN, Dtype)(
143
149
int row6 = mad24 (global_y , TILE_M , 6 ) < M ? 6 : border ;
144
150
int row7 = mad24 (global_y , TILE_M , 7 ) < M ? 7 : border ;
145
151
146
- Dtype4 dot00 = (start_index != 0 ) ? vload4 (0 , dst_write0 ) : beta * vload4 (0 , dst_write0 );
147
- Dtype4 dot01 = (start_index != 0 ) ? vload4 (0 , dst_write0 + 1 * N ) : beta * vload4 (0 , dst_write0 + 1 * N );
148
- Dtype4 dot02 = (start_index != 0 ) ? vload4 (0 , dst_write0 + 2 * N ) : beta * vload4 (0 , dst_write0 + 2 * N );
149
- Dtype4 dot03 = (start_index != 0 ) ? vload4 (0 , dst_write0 + 3 * N ) : beta * vload4 (0 , dst_write0 + 3 * N );
150
- Dtype4 dot04 = (start_index != 0 ) ? vload4 (0 , dst_write0 + 4 * N ) : beta * vload4 (0 , dst_write0 + 4 * N );
151
- Dtype4 dot05 = (start_index != 0 ) ? vload4 (0 , dst_write0 + 5 * N ) : beta * vload4 (0 , dst_write0 + 5 * N );
152
- Dtype4 dot06 = (start_index != 0 ) ? vload4 (0 , dst_write0 + 6 * N ) : beta * vload4 (0 , dst_write0 + 6 * N );
153
- Dtype4 dot07 = (start_index != 0 ) ? vload4 (0 , dst_write0 + 7 * N ) : beta * vload4 (0 , dst_write0 + 7 * N );
152
+ Dtype4 dot00 = (start_index != 0 ) ? vload4 (0 , dst_write0 ) : BETA_ZERO_CHECK (( Dtype4 ) 0 , beta * vload4 (0 , dst_write0 ) );
153
+ Dtype4 dot01 = (start_index != 0 ) ? vload4 (0 , dst_write0 + 1 * N ) : BETA_ZERO_CHECK (( Dtype4 ) 0 , beta * vload4 (0 , dst_write0 + 1 * N ) );
154
+ Dtype4 dot02 = (start_index != 0 ) ? vload4 (0 , dst_write0 + 2 * N ) : BETA_ZERO_CHECK (( Dtype4 ) 0 , beta * vload4 (0 , dst_write0 + 2 * N ) );
155
+ Dtype4 dot03 = (start_index != 0 ) ? vload4 (0 , dst_write0 + 3 * N ) : BETA_ZERO_CHECK (( Dtype4 ) 0 , beta * vload4 (0 , dst_write0 + 3 * N ) );
156
+ Dtype4 dot04 = (start_index != 0 ) ? vload4 (0 , dst_write0 + 4 * N ) : BETA_ZERO_CHECK (( Dtype4 ) 0 , beta * vload4 (0 , dst_write0 + 4 * N ) );
157
+ Dtype4 dot05 = (start_index != 0 ) ? vload4 (0 , dst_write0 + 5 * N ) : BETA_ZERO_CHECK (( Dtype4 ) 0 , beta * vload4 (0 , dst_write0 + 5 * N ) );
158
+ Dtype4 dot06 = (start_index != 0 ) ? vload4 (0 , dst_write0 + 6 * N ) : BETA_ZERO_CHECK (( Dtype4 ) 0 , beta * vload4 (0 , dst_write0 + 6 * N ) );
159
+ Dtype4 dot07 = (start_index != 0 ) ? vload4 (0 , dst_write0 + 7 * N ) : BETA_ZERO_CHECK (( Dtype4 ) 0 , beta * vload4 (0 , dst_write0 + 7 * N ) );
154
160
155
161
int end_index = min (start_index + 256 , K );
156
162
int w = start_index ;
@@ -579,7 +585,7 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)(
579
585
output = (local_x == 5) ? _dot.s5 : output; \
580
586
output = (local_x == 6) ? _dot.s6 : output; \
581
587
output = (local_x == 7) ? _dot.s7 : output; \
582
- dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); \
588
+ dst_write0[0] = BETA_ZERO_CHECK(alpha * output, mad(output, alpha, beta * dst_write0[0]) ); \
583
589
dst_write0 += N;
584
590
585
591
if (global_x < N && global_y * 8 < M ) {
@@ -765,7 +771,7 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)(
765
771
output = (local_x == 5) ? _dot.s5 : output; \
766
772
output = (local_x == 6) ? _dot.s6 : output; \
767
773
output = (local_x == 7) ? _dot.s7 : output; \
768
- dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); \
774
+ dst_write0[0] = BETA_ZERO_CHECK(alpha * output, mad(output, alpha, beta * dst_write0[0]) ); \
769
775
dst_write0 += N;
770
776
771
777
if (global_x < N && global_y * 8 < M ) {
@@ -819,8 +825,9 @@ void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)(
819
825
const Dtype4 b1 = {srca_read1 [i * 4 ], srca_read1 [(i * 4 + 1 )], srca_read1 [(i * 4 + 2 )], srca_read1 [(i * 4 + 3 )]};
820
826
#pragma unroll
821
827
for (int j = 0 ; j < rows ; ++ j ) {
822
- dot0 [j ] += b0 * vload4 (i , srcb_read + j * K );
823
- dot1 [j ] += b1 * vload4 (i , srcb_read + j * K );
828
+ Dtype4 a = vload4 (i , srcb_read + j * K );
829
+ dot0 [j ] += b0 * a ;
830
+ dot1 [j ] += b1 * a ;
824
831
}
825
832
826
833
i += get_local_size (0 );
@@ -859,11 +866,19 @@ void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)(
859
866
}
860
867
}
861
868
869
+ barrier (CLK_LOCAL_MEM_FENCE );
862
870
if (lid == 0 ) {
863
871
#pragma unroll
864
872
for (int j = 0 ; j < rows ; ++ j ) {
865
- dstc0 [(x_gid * 4 + j )] = alpha * work_each0 [j ] + beta * dstc0 [(x_gid * 4 + j )];
866
- dstc1 [(x_gid * 4 + j )] = alpha * work_each1 [j ] + beta * dstc1 [(x_gid * 4 + j )];
873
+ #ifdef ZERO_BETA
874
+ Dtype a0 = alpha * work_each0 [j ];
875
+ Dtype a1 = alpha * work_each1 [j ];
876
+ #else
877
+ Dtype a0 = alpha * work_each0 [j ] + beta * dstc0 [(x_gid * 4 + j )];
878
+ Dtype a1 = alpha * work_each1 [j ] + beta * dstc1 [(x_gid * 4 + j )];
879
+ #endif
880
+ dstc0 [(x_gid * 4 + j )] = a0 ;
881
+ dstc1 [(x_gid * 4 + j )] = a1 ;
867
882
}
868
883
}
869
884
}
@@ -952,9 +967,15 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_2,Dtype)(
952
967
}
953
968
}
954
969
955
- if (lid == 0 ) {
970
+ if (lid == 0 )
971
+ {
972
+ #ifdef ZERO_BETA
973
+ dstc0 [x_gid ] = alpha * work0 [0 ];
974
+ dstc1 [x_gid ] = alpha * work1 [0 ];
975
+ #else
956
976
dstc0 [x_gid ] = alpha * work0 [0 ] + beta * dstc0 [x_gid ];
957
977
dstc1 [x_gid ] = alpha * work1 [0 ] + beta * dstc1 [x_gid ];
978
+ #endif
958
979
}
959
980
}
960
981
}
@@ -1058,10 +1079,17 @@ void TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype)(
1058
1079
if (lid == 0 ) {
1059
1080
#pragma unroll
1060
1081
for (int j = 0 ; j < rows ; ++ j ) {
1082
+ #ifdef ZERO_BETA
1083
+ dstc0 [(x_gid * 4 + j )] = alpha * work_each0 [j ];
1084
+ dstc1 [(x_gid * 4 + j )] = alpha * work_each1 [j ];
1085
+ dstc2 [(x_gid * 4 + j )] = alpha * work_each2 [j ];
1086
+ dstc3 [(x_gid * 4 + j )] = alpha * work_each3 [j ];
1087
+ #else
1061
1088
dstc0 [(x_gid * 4 + j )] = alpha * work_each0 [j ] + beta * dstc0 [(x_gid * 4 + j )];
1062
1089
dstc1 [(x_gid * 4 + j )] = alpha * work_each1 [j ] + beta * dstc1 [(x_gid * 4 + j )];
1063
1090
dstc2 [(x_gid * 4 + j )] = alpha * work_each2 [j ] + beta * dstc2 [(x_gid * 4 + j )];
1064
1091
dstc3 [(x_gid * 4 + j )] = alpha * work_each3 [j ] + beta * dstc3 [(x_gid * 4 + j )];
1092
+ #endif
1065
1093
}
1066
1094
}
1067
1095
}
@@ -1179,10 +1207,17 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_4,Dtype)(
1179
1207
}
1180
1208
1181
1209
if (lid == 0 ) {
1210
+ #ifdef ZERO_BETA
1211
+ dstc0 [x_gid ] = alpha * work0 [0 ];
1212
+ dstc1 [x_gid ] = alpha * work1 [0 ];
1213
+ dstc2 [x_gid ] = alpha * work2 [0 ];
1214
+ dstc3 [x_gid ] = alpha * work3 [0 ];
1215
+ #else
1182
1216
dstc0 [x_gid ] = alpha * work0 [0 ] + beta * dstc0 [x_gid ];
1183
1217
dstc1 [x_gid ] = alpha * work1 [0 ] + beta * dstc1 [x_gid ];
1184
1218
dstc2 [x_gid ] = alpha * work2 [0 ] + beta * dstc2 [x_gid ];
1185
1219
dstc3 [x_gid ] = alpha * work3 [0 ] + beta * dstc3 [x_gid ];
1220
+ #endif
1186
1221
}
1187
1222
}
1188
1223
}
@@ -1320,6 +1355,16 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)(
1320
1355
}
1321
1356
1322
1357
if (lid == 0 ) {
1358
+ #ifdef ZERO_BETA
1359
+ dstc0 [x_gid ] = alpha * work0 [0 ];
1360
+ dstc1 [x_gid ] = alpha * work1 [0 ];
1361
+ dstc2 [x_gid ] = alpha * work2 [0 ];
1362
+ dstc3 [x_gid ] = alpha * work3 [0 ];
1363
+ dstc4 [x_gid ] = alpha * work4 [0 ];
1364
+ dstc5 [x_gid ] = alpha * work5 [0 ];
1365
+ dstc6 [x_gid ] = alpha * work6 [0 ];
1366
+ dstc7 [x_gid ] = alpha * work7 [0 ];
1367
+ #else
1323
1368
dstc0 [x_gid ] = alpha * work0 [0 ] + beta * dstc0 [x_gid ];
1324
1369
dstc1 [x_gid ] = alpha * work1 [0 ] + beta * dstc1 [x_gid ];
1325
1370
dstc2 [x_gid ] = alpha * work2 [0 ] + beta * dstc2 [x_gid ];
@@ -1328,6 +1373,7 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)(
1328
1373
dstc5 [x_gid ] = alpha * work5 [0 ] + beta * dstc5 [x_gid ];
1329
1374
dstc6 [x_gid ] = alpha * work6 [0 ] + beta * dstc6 [x_gid ];
1330
1375
dstc7 [x_gid ] = alpha * work7 [0 ] + beta * dstc7 [x_gid ];
1376
+ #endif
1331
1377
}
1332
1378
}
1333
1379
#undef SLM_SIZE
0 commit comments