Skip to content

Commit 0715985

Browse files
committed
* Performance fixes: minimized branch divergence, uses collectives to
eliminate redundant calculation, macros removed. * Kernel shared memory size check * Updates test-backend-ops to support graphs for performance measurement.
1 parent 720b483 commit 0715985

File tree

2 files changed

+140
-135
lines changed

2 files changed

+140
-135
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,29 +1011,35 @@ class vk_perf_logger {
10111011
void print_timings() {
10121012
if(timings.empty()){
10131013
return;
1014-
}
1014+
}
1015+
uint64_t total_all_op_times = 0;
10151016
std::cerr << "----------------\nVulkan Timings:" << std::endl;
10161017
for (const auto& t : timings) {
1017-
uint64_t total = 0;
1018+
uint64_t total_op_times = 0;
10181019
for (const auto& time : t.second) {
1019-
total += time;
1020+
total_op_times += time;
10201021
}
1021-
std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " us";
1022+
std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0) << " us";
10221023

10231024
// If we have as many flops entries as timing entries for the op, then compute and log the flops/S.
10241025
auto it = flops.find(t.first);
10251026
if(it != flops.end() && (it->second).size() == t.second.size()){
1026-
uint64_t total_nflops = 0;
1027+
uint64_t total_op_flops = 0;
10271028
for(const auto& elem : it->second){
1028-
total_nflops += elem;
1029+
total_op_flops += elem;
10291030
}
1030-
std::cout << " (" << (double(total_nflops)/(1000.0*1000.0*1000.0)) / (double(total)/(1000.0*1000.0*1000.0)) << " GFLOPS/s)";
1031+
std::cerr << " (" << (double(total_op_flops)/(1000.0*1000.0*1000.0)) / (double(total_op_times)/(1000.0*1000.0*1000.0)) << " GFLOPS/s)";
10311032
}
10321033

1034+
total_all_op_times += total_op_times;
10331035

10341036
std::cerr << std::endl;
10351037
}
10361038

1039+
if(timings.size() > 0){
1040+
std::cerr << "Total time: " << total_all_op_times/1000.0 << " us." << std::endl;
1041+
}
1042+
10371043
timings.clear();
10381044
flops.clear();
10391045
}
@@ -1072,6 +1078,7 @@ class vk_perf_logger {
10721078
uint64_t size_K = Cin*KW*KH;
10731079
uint64_t size_N = N*OW*OH;
10741080
uint64_t n_flops = size_M*size_N*(size_K+(size_K-1));
1081+
name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) + ", N=N*OW*OH=" + std::to_string(size_N);
10751082
flops[name].push_back(n_flops);
10761083
timings[name].push_back(time);
10771084
return;
@@ -3026,7 +3033,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
30263033

30273034
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
30283035

3029-
ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants), {128 /* equal to BS_K in the shader */, 128 /* equal to BS_NPQ in the shader */, 1}, {}, 1);
3036+
// conv2d
3037+
uint32_t conv2d_WG_SIZE = 256;
3038+
uint32_t conv2d_BS_K = 128;
3039+
uint32_t conv2d_BS_CRS = 16;
3040+
uint32_t conv2d_BS_NPQ = 128;
3041+
uint32_t conv2d_TS_K = 8;
3042+
uint32_t conv2d_shmem_req = (conv2d_BS_K*(conv2d_BS_CRS+1) + conv2d_BS_CRS*(conv2d_BS_NPQ+1))*sizeof(float);
3043+
if(device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req){
3044+
conv2d_BS_CRS = 8;
3045+
conv2d_TS_K = 8;
3046+
}
3047+
ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants), {conv2d_BS_K, conv2d_BS_NPQ, 1}, {conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K}, 1);
30303048

30313049
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
30323050
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
@@ -10200,6 +10218,11 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
1020010218
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
1020110219
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
1020210220
total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
10221+
}else if(cgraph->nodes[i]->op == GGML_OP_CONV_2D){
10222+
// Return CRSxNPQxsizeof(*) to account as many bytes as mul_mat has in im2col->mul_mat mode.
10223+
auto CRS_size = cgraph->nodes[i]->src[0]->ne[0]*cgraph->nodes[i]->src[0]->ne[1]*cgraph->nodes[i]->src[0]->ne[2];
10224+
auto NPQ_size = cgraph->nodes[i]->ne[0]*cgraph->nodes[i]->ne[1]*cgraph->nodes[i]->ne[3];
10225+
total_mat_mul_bytes += NPQ_size*CRS_size*ggml_type_size(cgraph->nodes[i]->type);
1020310226
}
1020410227
i += ctx->num_additional_fused_ops;
1020510228
ctx->num_additional_fused_ops = 0;
Lines changed: 109 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
#version 450
22

3-
#extension GL_EXT_control_flow_attributes : enable
3+
#define USE_COLLECTIVES
4+
5+
#ifdef USE_COLLECTIVES
6+
#extension GL_KHR_shader_subgroup_shuffle: enable
7+
#endif
48

59
#include "types.comp"
610

11+
// Make spec constant
12+
#define SHMEM_PAD 0
13+
714
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
815
layout (binding = 0) readonly buffer A {A_TYPE knl_data[];}; // src0 - kernel: [KW, KH, Cin, Cout]
916
layout (binding = 1) readonly buffer B {B_TYPE src_data[];}; // src1 - input: [W, H, Cin, N] -- channel_first format
@@ -45,12 +52,16 @@ layout (push_constant) uniform parameter {
4552
uint32_t nb3;
4653
} p;
4754

48-
#define WG_SIZE 256
49-
50-
layout(local_size_x = WG_SIZE, local_size_y = 1, local_size_z = 1) in;
55+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
56+
// Blocktile sizes
57+
layout(constant_id = 1) const uint BS_K = 128;
58+
layout(constant_id = 2) const uint BS_CRS = 16;
59+
layout(constant_id = 3) const uint BS_NPQ = 128;
60+
// Thread-tile sizes
61+
layout(constant_id = 4) const uint TS_K = 8;
5162

5263
uint32_t tid = gl_LocalInvocationID.x;
53-
const uint32_t bs = gl_WorkGroupSize.x;
64+
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
5465

5566
uint splitWork(uint work_size, uint block_size){
5667
return (block_size + work_size -1) / block_size;
@@ -62,16 +73,11 @@ uint32_t NPQ = p.N*p.OH*p.OW;
6273

6374
uint32_t n_elems_out = K*NPQ;
6475

65-
// Blocktile sizes
66-
const uint32_t BS_K = 128;
67-
const uint32_t BS_CRS = 16;
68-
const uint32_t BS_NPQ = 128;
69-
7076
// Number of blocktiles per input
7177
uint32_t NB_CRS = splitWork(CRS, BS_CRS);
7278

73-
const uint32_t Ash_stride = BS_CRS+1;
74-
const uint32_t Bsh_stride = BS_NPQ+1;
79+
const uint32_t Ash_stride = BS_CRS+SHMEM_PAD;
80+
const uint32_t Bsh_stride = BS_NPQ+SHMEM_PAD;
7581

7682
const uint32_t Ash_numel = BS_K*BS_CRS;
7783
const uint32_t Bsh_numel = BS_CRS*BS_NPQ;
@@ -83,7 +89,6 @@ shared float Ash[Ash_len]; // K x CRS
8389
shared float Bsh[Bsh_len]; // CRS x NPQ
8490

8591
// Threadtile sizes
86-
const uint32_t TS_K = 16;
8792
const uint32_t TS_NPQ = BS_K*BS_NPQ / WG_SIZE / TS_K;
8893

8994
// Number of threadtiles per blocktile
@@ -111,134 +116,111 @@ uint32_t T_x = tid % NT_NPQ;
111116

112117
uint32_t Ar = tid / BS_CRS;
113118
uint32_t Ac = tid % BS_CRS;
114-
uint32_t ArpWg = WG_SIZE / BS_CRS;
119+
const uint32_t ArpWg = WG_SIZE / BS_CRS;
115120

116121
uint32_t Br = tid / BS_NPQ;
117122
uint32_t Bc = tid % BS_NPQ;
118-
uint32_t BrpWg = WG_SIZE / BS_NPQ;
123+
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
119124

120-
void initReg(){
125+
void main(){\
121126
for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){
122127
for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){
123128
regC[T_ly][T_lx] = 0.0;
124129
}
125130
}
126-
}
127-
128-
void outProdReg(){
129-
for(uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++){
130-
for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){
131-
regA[T_ly] = Ash[(T_y*TS_K + T_ly)*Ash_stride + CRS_lidx];
131+
/* Advance block in CRS dim */\
132+
for(uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++){
133+
#ifdef USE_COLLECTIVES
134+
uint32_t cached_CRS_idx = B_idx_CRS*BS_CRS + gl_SubgroupInvocationID;
135+
uint32_t cached_Cin_idx = cached_CRS_idx / (p.KW*p.KH);
136+
uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx*p.KW*p.KH);
137+
uint32_t cached_KH_idx = cached_CRS_remainder / p.KW;
138+
uint32_t cached_KW_idx = cached_CRS_remainder - cached_KH_idx*p.KW;
139+
140+
uint32_t CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
141+
uint32_t Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
142+
uint32_t KH_idx_a = subgroupShuffle(cached_KH_idx, Ac);
143+
uint32_t KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
144+
#else
145+
uint32_t CRS_idx_a = B_idx_CRS*BS_CRS + Ac; // Global CRS_idx_a (column index of A)
146+
uint32_t Cin_idx_a = CRS_idx_a / (p.KW*p.KH);
147+
uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a*p.KW*p.KH;
148+
uint32_t KH_idx_a = CRS_remainder / p.KW;
149+
uint32_t KW_idx_a = CRS_remainder - KH_idx_a*p.KW;
150+
#endif
151+
152+
/* Load kernel to A_block: (BS_K x BS_CRS)*/
153+
for(uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg){
154+
uint32_t B_ly = r_offset + Ar;
155+
uint32_t B_lx = Ac;
156+
uint32_t K_idx = B_idx_K*BS_K + B_ly; /* Global K_idx (row index of A)*/
157+
uint32_t knl_idx = min(KW_idx_a + KH_idx_a*p.nb01 + Cin_idx_a*p.nb02 + K_idx*p.nb03, K*CRS-1);
158+
float val = knl_data[knl_idx];
159+
if(K_idx >= K || CRS_idx_a >= CRS){
160+
val = 0.0;
161+
}
162+
Ash[B_ly * Ash_stride + B_lx] = val;
132163
}
133-
for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){
134-
regB[T_lx] = Bsh[CRS_lidx*Bsh_stride + T_x*TS_NPQ+T_lx];
164+
/* Load input to B_block: (BS_CRS x BS_NPQ) */
165+
for(uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg){
166+
uint32_t B_ly = r_offset + Br; /* Row index of B block */
167+
uint32_t B_lx = Bc;
168+
uint32_t NPQ_idx = B_idx_NPQ*BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
169+
uint32_t N_idx = NPQ_idx / (p.OH*p.OW);
170+
uint32_t NPQ_remainder = NPQ_idx - N_idx*p.OH*p.OW;
171+
uint32_t OH_idx = NPQ_remainder / p.OW;
172+
uint32_t OW_idx = NPQ_remainder - OH_idx*p.OW;
173+
174+
#ifdef USE_COLLECTIVES
175+
uint32_t CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br);
176+
uint32_t Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br);
177+
uint32_t KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br);
178+
uint32_t KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
179+
#else
180+
uint32_t CRS_idx_b = B_idx_CRS*BS_CRS + B_ly; /* Global CRS index (row index of B) */
181+
uint32_t Cin_idx_b = CRS_idx_b / (p.KW*p.KH);
182+
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b*p.KW*p.KH;
183+
uint32_t KH_idx_b = CRS_remainder / p.KW;
184+
uint32_t KW_idx_b = CRS_remainder - KH_idx_b*p.KW;
185+
#endif
186+
187+
uint32_t H_idx = OH_idx*p.s1 + KH_idx_b*p.d1 - p.p1;
188+
uint32_t W_idx = OW_idx*p.s0 + KW_idx_b*p.d0 - p.p0;
189+
uint32_t src_idx = min(max(W_idx + H_idx*p.nb11 + Cin_idx_b*p.nb12 + N_idx*p.nb13, 0), p.Cin*p.N*p.W*p.H-1);
190+
float val = src_data[src_idx];
191+
if(CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W){
192+
val = 0.0;
193+
}
194+
Bsh[B_ly * Bsh_stride + B_lx] = val;
135195
}
136-
for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){
196+
barrier();
197+
for(uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++){
198+
for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){
199+
regA[T_ly] = Ash[(T_y*TS_K + T_ly)*Ash_stride + CRS_lidx];
200+
}
137201
for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){
138-
regC[T_ly][T_lx] += regA[T_ly] * regB[T_lx];
202+
regB[T_lx] = Bsh[CRS_lidx*Bsh_stride + T_x*TS_NPQ+T_lx];
203+
}
204+
for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){
205+
for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){
206+
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
207+
}
139208
}
140209
}
210+
barrier();
141211
}
142-
}
143-
144-
// Generate different functions for computing the sides.
145-
146-
#define NOOP()
147-
148-
#define DEF_BOUNDARY_CONDITION_A_IF()\
149-
if(K_idx < K && CRS_idx < CRS){
150-
151-
#define DEF_BOUNDARY_CONDITION_A_ELSE()\
152-
}else{\
153-
Ash[B_ly * Ash_stride + B_lx] = 0.0;\
154-
}
155-
156-
#define DEF_BOUNDARY_CONDITION_B_IF()\
157-
if(CRS_idx < CRS && NPQ_idx < NPQ){
158-
159-
#define DEF_BOUNDARY_CONDITION_B_ELSE()\
160-
}else{\
161-
Bsh[B_ly * Bsh_stride + B_lx] = 0.0;\
162-
}
163-
164-
#define MAIN_LOOP(FUNC_NAME_SUFFIX, BOUNDARY_CONDITION_A_IF, BOUNDARY_CONDITION_A_ELSE, BOUNDARY_CONDITION_B_IF, BOUNDARY_CONDITION_B_ELSE)\
165-
void mainLoop ## FUNC_NAME_SUFFIX(){\
166-
initReg();\
167-
/* Advance block in CRS dim */\
168-
for(uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++){\
169-
/* Load kernel to A_block: (BS_K x BS_CRS)*/\
170-
for(uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg){\
171-
uint32_t B_ly = r_offset + Ar;\
172-
uint32_t B_lx = Ac;\
173-
uint32_t K_idx = B_idx_K*BS_K + B_ly; /* Global K_idx (row index of A)*/\
174-
uint32_t CRS_idx = B_idx_CRS*BS_CRS + B_lx; /* Global CRS_idx (column index of A)*/\
175-
BOUNDARY_CONDITION_A_IF()\
176-
uint32_t Cin_idx = CRS_idx / (p.KW*p.KH);\
177-
uint32_t KH_idx = (CRS_idx - Cin_idx*p.KW*p.KH) / p.KW;\
178-
uint32_t KW_idx = CRS_idx - Cin_idx*p.KW*p.KH - KH_idx*p.KW;\
179-
uint32_t knl_idx = KW_idx + KH_idx*p.nb01 + Cin_idx*p.nb02 + K_idx*p.nb03;\
180-
Ash[B_ly * Ash_stride + B_lx] = knl_data[knl_idx];\
181-
BOUNDARY_CONDITION_A_ELSE()\
182-
}\
183-
barrier();\
184-
/* Load input to B_block: (BS_CRS x BS_NPQ) */\
185-
for(uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg){\
186-
uint32_t B_ly = r_offset + Br; /* Row index of B block */\
187-
uint32_t B_lx = Bc; /* Column index of B block */\
188-
uint32_t CRS_idx = B_idx_CRS*BS_CRS + B_ly; /* Global CRS index (row index of B) */\
189-
uint32_t NPQ_idx = B_idx_NPQ*BS_NPQ + B_lx; /* Global NPQ index (column index of B) */\
190-
BOUNDARY_CONDITION_B_IF()\
191-
uint32_t Cin_idx = CRS_idx / (p.KW*p.KH);\
192-
uint32_t KH_idx = (CRS_idx - Cin_idx*p.KW*p.KH) / p.KW;\
193-
uint32_t KW_idx = CRS_idx - Cin_idx*p.KW*p.KH - KH_idx*p.KW;\
194-
uint32_t N_idx = NPQ_idx / (p.OH*p.OW);\
195-
uint32_t OH_idx = (NPQ_idx - N_idx*p.OH*p.OW) / p.OW;\
196-
uint32_t OW_idx = NPQ_idx - N_idx*p.OH*p.OW - OH_idx*p.OW;\
197-
uint32_t H_idx = OH_idx*p.s1 + KH_idx*p.d1 - p.p1;\
198-
uint32_t W_idx = OW_idx*p.s0 + KW_idx*p.d0 - p.p0;\
199-
if(H_idx >= 0 && H_idx < p.H && W_idx >= 0 && W_idx < p.W){\
200-
uint32_t src_idx = W_idx + H_idx*p.nb11 + Cin_idx*p.nb12 + N_idx*p.nb13;\
201-
Bsh[B_ly * Bsh_stride + B_lx] = src_data[src_idx];\
202-
}else{\
203-
Bsh[B_ly * Bsh_stride + B_lx] = 0.0;\
204-
}\
205-
BOUNDARY_CONDITION_B_ELSE()\
206-
}\
207-
barrier();\
208-
outProdReg();\
209-
barrier();\
210-
}\
211-
/* Save C* */\
212-
for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){\
213-
for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){\
214-
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;\
215-
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;\
216-
if(K_idx < K && NPQ_idx < NPQ){\
217-
uint32_t N_idx = NPQ_idx / (p.OH*p.OW);\
218-
uint32_t OH_idx = (NPQ_idx - N_idx*p.OH*p.OW) / p.OW;\
219-
uint32_t OW_idx = NPQ_idx - N_idx*p.OH*p.OW - OH_idx*p.OW;\
220-
uint32_t dst_idx = OW_idx + OH_idx*p.nb1 + K_idx*p.nb2 + N_idx*p.nb3;\
221-
dst_data[dst_idx] = regC[T_ly][T_lx];\
222-
}\
223-
}\
224-
}\
225-
}
226-
227-
// Generates mainLoopBoundaryCheck
228-
MAIN_LOOP(BoundaryCheck,
229-
DEF_BOUNDARY_CONDITION_A_IF,
230-
DEF_BOUNDARY_CONDITION_A_ELSE,
231-
DEF_BOUNDARY_CONDITION_B_IF,
232-
DEF_BOUNDARY_CONDITION_B_ELSE)
233-
234-
// Generates mainLoopNoBoundaryCheck
235-
MAIN_LOOP(NoBoundaryCheck,
236-
NOOP, NOOP, NOOP, NOOP)
237-
238-
void main(){
239-
if(gl_WorkGroupID.x == gl_NumWorkGroups.x-1 || gl_WorkGroupID.y == gl_NumWorkGroups.y-1){
240-
mainLoopBoundaryCheck();
241-
}else{
242-
mainLoopNoBoundaryCheck();
212+
/* Save C* */
213+
for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){
214+
for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){
215+
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
216+
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
217+
uint32_t N_idx = NPQ_idx / (p.OH*p.OW);
218+
uint32_t OH_idx = (NPQ_idx - N_idx*p.OH*p.OW) / p.OW;
219+
uint32_t OW_idx = NPQ_idx - N_idx*p.OH*p.OW - OH_idx*p.OW;
220+
uint32_t dst_idx = OW_idx + OH_idx*p.nb1 + K_idx*p.nb2 + N_idx*p.nb3;
221+
if(K_idx < K && NPQ_idx < NPQ){
222+
dst_data[dst_idx] = regC[T_ly][T_lx];
223+
}
224+
}
243225
}
244226
}

0 commit comments

Comments
 (0)