#version 450

- #extension GL_EXT_control_flow_attributes : enable
+ #define USE_COLLECTIVES
+
+ #ifdef USE_COLLECTIVES
+ #extension GL_KHR_shader_subgroup_shuffle: enable
+ #endif

#include "types.comp"

+ // Make spec constant
+ #define SHMEM_PAD 0
+
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
layout (binding = 0) readonly buffer A {A_TYPE knl_data[];}; // src0 - kernel: [KW, KH, Cin, Cout]
layout (binding = 1) readonly buffer B {B_TYPE src_data[];}; // src1 - input: [W, H, Cin, N] -- channel_first format
@@ -45,12 +52,16 @@ layout (push_constant) uniform parameter {
    uint32_t nb3;
} p;

- #define WG_SIZE 256
-
- layout(local_size_x = WG_SIZE, local_size_y = 1, local_size_z = 1) in;
+ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+ // Blocktile sizes
+ layout(constant_id = 1) const uint BS_K = 128;
+ layout(constant_id = 2) const uint BS_CRS = 16;
+ layout(constant_id = 3) const uint BS_NPQ = 128;
+ // Thread-tile sizes
+ layout(constant_id = 4) const uint TS_K = 8;

uint32_t tid = gl_LocalInvocationID.x;
- const uint32_t bs = gl_WorkGroupSize.x;
+ const uint32_t WG_SIZE = gl_WorkGroupSize.x;

uint splitWork(uint work_size, uint block_size){
    return (block_size + work_size - 1) / block_size;
@@ -62,16 +73,11 @@ uint32_t NPQ = p.N*p.OH*p.OW;

uint32_t n_elems_out = K*NPQ;

- // Blocktile sizes
- const uint32_t BS_K = 128;
- const uint32_t BS_CRS = 16;
- const uint32_t BS_NPQ = 128;
-
// Number of blocktiles per input
uint32_t NB_CRS = splitWork(CRS, BS_CRS);

- const uint32_t Ash_stride = BS_CRS+1;
- const uint32_t Bsh_stride = BS_NPQ+1;
+ const uint32_t Ash_stride = BS_CRS+SHMEM_PAD;
+ const uint32_t Bsh_stride = BS_NPQ+SHMEM_PAD;

const uint32_t Ash_numel = BS_K*BS_CRS;
const uint32_t Bsh_numel = BS_CRS*BS_NPQ;
@@ -83,7 +89,6 @@ shared float Ash[Ash_len]; // K x CRS
shared float Bsh[Bsh_len]; // CRS x NPQ

// Threadtile sizes
- const uint32_t TS_K = 16;
const uint32_t TS_NPQ = BS_K*BS_NPQ / WG_SIZE / TS_K;

// Number of threadtiles per blocktile
@@ -111,134 +116,111 @@ uint32_t T_x = tid % NT_NPQ;

uint32_t Ar = tid / BS_CRS;
uint32_t Ac = tid % BS_CRS;
- uint32_t ArpWg = WG_SIZE / BS_CRS;
+ const uint32_t ArpWg = WG_SIZE / BS_CRS;

uint32_t Br = tid / BS_NPQ;
uint32_t Bc = tid % BS_NPQ;
- uint32_t BrpWg = WG_SIZE / BS_NPQ;
+ const uint32_t BrpWg = WG_SIZE / BS_NPQ;

- void initReg (){
+ void main(){
    for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){
        for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){
            regC[T_ly][T_lx] = 0.0;
        }
    }
- }
-
- void outProdReg(){
-     for(uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++){
-         for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){
-             regA[T_ly] = Ash[(T_y*TS_K + T_ly)*Ash_stride + CRS_lidx];
+     /* Advance block in CRS dim */
+     for(uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++){
+         #ifdef USE_COLLECTIVES
+         uint32_t cached_CRS_idx = B_idx_CRS*BS_CRS + gl_SubgroupInvocationID;
+         uint32_t cached_Cin_idx = cached_CRS_idx / (p.KW*p.KH);
+         uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx*p.KW*p.KH);
+         uint32_t cached_KH_idx = cached_CRS_remainder / p.KW;
+         uint32_t cached_KW_idx = cached_CRS_remainder - cached_KH_idx*p.KW;
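+         // Each subgroup lane has decomposed one CRS index into (Cin, KH, KW);
+         // the subgroupShuffle calls below let every lane read the decomposition
+         // it needs without redoing the integer divisions.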
+
+         uint32_t CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
+         uint32_t Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
+         uint32_t KH_idx_a = subgroupShuffle(cached_KH_idx, Ac);
+         uint32_t KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
+         #else
+         uint32_t CRS_idx_a = B_idx_CRS*BS_CRS + Ac; // Global CRS_idx_a (column index of A)
+         uint32_t Cin_idx_a = CRS_idx_a / (p.KW*p.KH);
+         uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a*p.KW*p.KH;
+         uint32_t KH_idx_a = CRS_remainder / p.KW;
+         uint32_t KW_idx_a = CRS_remainder - KH_idx_a*p.KW;
+         #endif
+
+         /* Load kernel to A_block: (BS_K x BS_CRS)*/
+         for(uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg){
+             uint32_t B_ly = r_offset + Ar;
+             uint32_t B_lx = Ac;
+             uint32_t K_idx = B_idx_K*BS_K + B_ly; /* Global K_idx (row index of A)*/
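+             // The load index is clamped into the kernel buffer; out-of-range
+             // (K_idx, CRS_idx_a) positions are zeroed just below instead of skipped.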
+             uint32_t knl_idx = min(KW_idx_a + KH_idx_a*p.nb01 + Cin_idx_a*p.nb02 + K_idx*p.nb03, K*CRS-1);
+             float val = knl_data[knl_idx];
+             if(K_idx >= K || CRS_idx_a >= CRS){
+                 val = 0.0;
+             }
+             Ash[B_ly * Ash_stride + B_lx] = val;
        }
-         for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){
-             regB[T_lx] = Bsh[CRS_lidx*Bsh_stride + T_x*TS_NPQ+T_lx];
+         /* Load input to B_block: (BS_CRS x BS_NPQ) */
+         for(uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg){
+             uint32_t B_ly = r_offset + Br; /* Row index of B block */
+             uint32_t B_lx = Bc;
+             uint32_t NPQ_idx = B_idx_NPQ*BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
+             uint32_t N_idx = NPQ_idx / (p.OH*p.OW);
+             uint32_t NPQ_remainder = NPQ_idx - N_idx*p.OH*p.OW;
+             uint32_t OH_idx = NPQ_remainder / p.OW;
+             uint32_t OW_idx = NPQ_remainder - OH_idx*p.OW;
+
+             #ifdef USE_COLLECTIVES
+             uint32_t CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br);
+             uint32_t Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br);
+             uint32_t KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br);
+             uint32_t KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
+             #else
+             uint32_t CRS_idx_b = B_idx_CRS*BS_CRS + B_ly; /* Global CRS index (row index of B) */
+             uint32_t Cin_idx_b = CRS_idx_b / (p.KW*p.KH);
+             uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b*p.KW*p.KH;
+             uint32_t KH_idx_b = CRS_remainder / p.KW;
+             uint32_t KW_idx_b = CRS_remainder - KH_idx_b*p.KW;
+             #endif
+
+             uint32_t H_idx = OH_idx*p.s1 + KH_idx_b*p.d1 - p.p1;
+             uint32_t W_idx = OW_idx*p.s0 + KW_idx_b*p.d0 - p.p0;
+             uint32_t src_idx = min(max(W_idx + H_idx*p.nb11 + Cin_idx_b*p.nb12 + N_idx*p.nb13, 0), p.Cin*p.N*p.W*p.H-1);
+             float val = src_data[src_idx];
+             if(CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W){
+                 val = 0.0;
+             }
+             Bsh[B_ly * Bsh_stride + B_lx] = val;
        }
-         for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){
+         barrier();
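+         // Outer-product accumulation: each thread multiplies its TS_K slice of Ash
+         // with its TS_NPQ slice of Bsh and accumulates into its per-thread regC tile.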
+         for(uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++){
+             for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){
+                 regA[T_ly] = Ash[(T_y*TS_K + T_ly)*Ash_stride + CRS_lidx];
+             }
            for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){
-                 regC[T_ly][T_lx] += regA[T_ly] * regB[T_lx];
+                 regB[T_lx] = Bsh[CRS_lidx*Bsh_stride + T_x*TS_NPQ+T_lx];
+             }
+             for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){
+                 for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){
+                     regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
+                 }
            }
        }
+         barrier();
    }
- }
-
- // Generate different functions for computing the sides.
-
- #define NOOP()
-
- #define DEF_BOUNDARY_CONDITION_A_IF()\
-     if(K_idx < K && CRS_idx < CRS){
-
- #define DEF_BOUNDARY_CONDITION_A_ELSE()\
-     }else{\
-         Ash[B_ly * Ash_stride + B_lx] = 0.0;\
-     }
-
- #define DEF_BOUNDARY_CONDITION_B_IF()\
-     if(CRS_idx < CRS && NPQ_idx < NPQ){
-
- #define DEF_BOUNDARY_CONDITION_B_ELSE()\
-     }else{\
-         Bsh[B_ly * Bsh_stride + B_lx] = 0.0;\
-     }
-
- #define MAIN_LOOP(FUNC_NAME_SUFFIX, BOUNDARY_CONDITION_A_IF, BOUNDARY_CONDITION_A_ELSE, BOUNDARY_CONDITION_B_IF, BOUNDARY_CONDITION_B_ELSE)\
- void mainLoop ## FUNC_NAME_SUFFIX(){\
-     initReg();\
-     /* Advance block in CRS dim */\
-     for(uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++){\
-         /* Load kernel to A_block: (BS_K x BS_CRS)*/\
-         for(uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg){\
-             uint32_t B_ly = r_offset + Ar;\
-             uint32_t B_lx = Ac;\
-             uint32_t K_idx = B_idx_K*BS_K + B_ly; /* Global K_idx (row index of A)*/\
-             uint32_t CRS_idx = B_idx_CRS*BS_CRS + B_lx; /* Global CRS_idx (column index of A)*/\
-             BOUNDARY_CONDITION_A_IF()\
-                 uint32_t Cin_idx = CRS_idx / (p.KW*p.KH);\
-                 uint32_t KH_idx = (CRS_idx - Cin_idx*p.KW*p.KH) / p.KW;\
-                 uint32_t KW_idx = CRS_idx - Cin_idx*p.KW*p.KH - KH_idx*p.KW;\
-                 uint32_t knl_idx = KW_idx + KH_idx*p.nb01 + Cin_idx*p.nb02 + K_idx*p.nb03;\
-                 Ash[B_ly * Ash_stride + B_lx] = knl_data[knl_idx];\
-             BOUNDARY_CONDITION_A_ELSE()\
-         }\
-         barrier();\
-         /* Load input to B_block: (BS_CRS x BS_NPQ) */\
-         for(uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg){\
-             uint32_t B_ly = r_offset + Br; /* Row index of B block */\
-             uint32_t B_lx = Bc; /* Column index of B block */\
-             uint32_t CRS_idx = B_idx_CRS*BS_CRS + B_ly; /* Global CRS index (row index of B) */\
-             uint32_t NPQ_idx = B_idx_NPQ*BS_NPQ + B_lx; /* Global NPQ index (column index of B) */\
-             BOUNDARY_CONDITION_B_IF()\
-                 uint32_t Cin_idx = CRS_idx / (p.KW*p.KH);\
-                 uint32_t KH_idx = (CRS_idx - Cin_idx*p.KW*p.KH) / p.KW;\
-                 uint32_t KW_idx = CRS_idx - Cin_idx*p.KW*p.KH - KH_idx*p.KW;\
-                 uint32_t N_idx = NPQ_idx / (p.OH*p.OW);\
-                 uint32_t OH_idx = (NPQ_idx - N_idx*p.OH*p.OW) / p.OW;\
-                 uint32_t OW_idx = NPQ_idx - N_idx*p.OH*p.OW - OH_idx*p.OW;\
-                 uint32_t H_idx = OH_idx*p.s1 + KH_idx*p.d1 - p.p1;\
-                 uint32_t W_idx = OW_idx*p.s0 + KW_idx*p.d0 - p.p0;\
-                 if(H_idx >= 0 && H_idx < p.H && W_idx >= 0 && W_idx < p.W){\
-                     uint32_t src_idx = W_idx + H_idx*p.nb11 + Cin_idx*p.nb12 + N_idx*p.nb13;\
-                     Bsh[B_ly * Bsh_stride + B_lx] = src_data[src_idx];\
-                 }else{\
-                     Bsh[B_ly * Bsh_stride + B_lx] = 0.0;\
-                 }\
-             BOUNDARY_CONDITION_B_ELSE()\
-         }\
-         barrier();\
-         outProdReg();\
-         barrier();\
-     }\
-     /* Save C* */\
-     for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){\
-         for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){\
-             uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;\
-             uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;\
-             if(K_idx < K && NPQ_idx < NPQ){\
-                 uint32_t N_idx = NPQ_idx / (p.OH*p.OW);\
-                 uint32_t OH_idx = (NPQ_idx - N_idx*p.OH*p.OW) / p.OW;\
-                 uint32_t OW_idx = NPQ_idx - N_idx*p.OH*p.OW - OH_idx*p.OW;\
-                 uint32_t dst_idx = OW_idx + OH_idx*p.nb1 + K_idx*p.nb2 + N_idx*p.nb3;\
-                 dst_data[dst_idx] = regC[T_ly][T_lx];\
-             }\
-         }\
-     }\
- }
-
- // Generates mainLoopBoundaryCheck
- MAIN_LOOP(BoundaryCheck,
-     DEF_BOUNDARY_CONDITION_A_IF,
-     DEF_BOUNDARY_CONDITION_A_ELSE,
-     DEF_BOUNDARY_CONDITION_B_IF,
-     DEF_BOUNDARY_CONDITION_B_ELSE)
-
- // Generates mainLoopNoBoundaryCheck
- MAIN_LOOP(NoBoundaryCheck,
-     NOOP, NOOP, NOOP, NOOP)
-
- void main(){
-     if(gl_WorkGroupID.x == gl_NumWorkGroups.x-1 || gl_WorkGroupID.y == gl_NumWorkGroups.y-1){
-         mainLoopBoundaryCheck();
-     }else{
-         mainLoopNoBoundaryCheck();
+     /* Save C* */
+     for(uint32_t T_ly = 0; T_ly < TS_K; T_ly++){
+         for(uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++){
+             uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
+             uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
+             uint32_t N_idx = NPQ_idx / (p.OH*p.OW);
+             uint32_t OH_idx = (NPQ_idx - N_idx*p.OH*p.OW) / p.OW;
+             uint32_t OW_idx = NPQ_idx - N_idx*p.OH*p.OW - OH_idx*p.OW;
+             uint32_t dst_idx = OW_idx + OH_idx*p.nb1 + K_idx*p.nb2 + N_idx*p.nb3;
+             if(K_idx < K && NPQ_idx < NPQ){
+                 dst_data[dst_idx] = regC[T_ly][T_lx];
+             }
+         }
    }
}
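The workgroup size and tile sizes above are now specialization constants (constant_id 0-4) rather than compile-time #defines, so the host can choose them per device when the pipeline is created. A minimal host-side sketch, assuming plain Vulkan rather than the actual ggml-vulkan pipeline-creation code (make_conv2d_spec_info is a hypothetical helper introduced here for illustration):

#include <vulkan/vulkan.h>

/* Hypothetical helper: packs WG_SIZE, BS_K, BS_CRS, BS_NPQ, TS_K into a
   VkSpecializationInfo matching layout(constant_id = 0..4) in the shader above. */
static VkSpecializationInfo make_conv2d_spec_info(const uint32_t values[5],
                                                  VkSpecializationMapEntry entries[5]) {
    for (uint32_t i = 0; i < 5; ++i) {
        entries[i].constantID = i;                  /* constant_id in the shader */
        entries[i].offset     = i * sizeof(uint32_t);
        entries[i].size       = sizeof(uint32_t);
    }
    VkSpecializationInfo info;
    info.mapEntryCount = 5;
    info.pMapEntries   = entries;
    info.dataSize      = 5 * sizeof(uint32_t);
    info.pData         = values;
    return info;
}

/* Usage: pass the result via VkPipelineShaderStageCreateInfo::pSpecializationInfo,
   e.g. with values { 256, 128, 16, 128, 8 } to match the old WG_SIZE and the
   defaults declared in the shader. */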