Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 687f183

Browse files
Sven Verdoolaege authored and ftynse committed
tc2halide: drop temporary hack fusing reduction init and update statements
Since the detection of reductions in halide2isl no longer depends on init statements, the hack is no longer needed. It would only work if the init and the update are the only two statements anyway.
1 parent a74c757 commit 687f183

File tree

3 files changed, with 112 additions and 69 deletions.

tc/core/tc2halide.cc

Lines changed: 0 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -834,38 +834,6 @@ HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
834834
s = tagReduction.mutate(s);
835835
}
836836

837-
// Temporary hack: Fuse reduction initializations into the
838-
// updates. Shouldn't matter because we're going to reschedule
839-
// everything anyway, but starting at this point is more similar to
840-
// the code this replaces. Only correct because we only currently
841-
// create update definitions when there are reductions.
842-
class FuseReductions : public IRMutator2 {
843-
using IRMutator2::visit;
844-
Stmt visit(const Block* op) override {
845-
const For* first = op->first.as<For>();
846-
const For* rest = op->rest.as<For>();
847-
if (first && rest && equal(first->min, rest->min) &&
848-
equal(first->extent, rest->extent) &&
849-
first->for_type == rest->for_type &&
850-
replace_all(first->name, ".s0.", ".s1.") == rest->name) {
851-
Stmt body = rest->body;
852-
body =
853-
substitute(rest->name, Variable::make(Int(32), first->name), body);
854-
body = mutate(Block::make(first->body, body));
855-
return For::make(
856-
first->name,
857-
first->min,
858-
first->extent,
859-
first->for_type,
860-
first->device_api,
861-
body);
862-
} else {
863-
return IRMutator2::visit(op);
864-
}
865-
}
866-
} fuser;
867-
s = fuser.mutate(s);
868-
869837
// Trim ProducerConsumer annotations. TC doesn't use them.
870838
class RemoveProducerConsumer : public IRMutator2 {
871839
using IRMutator2::visit;

test/test_core.cc

Lines changed: 108 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,18 @@ struct GenericHalideCoreTest : public ::testing::Test {
7575
curPos = newPos;
7676
}
7777
}
78+
void CheckC(const std::string& tc, const std::string& expected) {
79+
std::istringstream stream(expected);
80+
std::string line;
81+
std::vector<std::string> split;
82+
while (std::getline(stream, line)) {
83+
// Skip lines containing (only) closing brace.
84+
if (line.find('}') == std::string::npos) {
85+
split.emplace_back(line);
86+
}
87+
}
88+
CheckC(tc, split);
89+
}
7890
};
7991

8092
TEST_F(GenericHalideCoreTest, TwoMatmul) {
@@ -86,18 +98,32 @@ def fun(float(M, K) I, float(K, N) W1, float(N, P) W2) -> (O1, O2) {
8698
)TC";
8799
CheckC(
88100
tc,
89-
{
90-
"for (int O1_s0_m = 0; O1_s0_m < M; O1_s0_m++) {",
91-
" for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {",
92-
" O1[O1_s0_m][O1_s0_n] = 0.000000f",
93-
" for (int O1_s1_r_k = 0; O1_s1_r_k < K; O1_s1_r_k++) {",
94-
" O1[O1_s0_m][O1_s0_n] = (O1[O1_s0_m][O1_s0_n] + (I[O1_s0_m][O1_s1_r_k]*W1[O1_s1_r_k][O1_s0_n]))",
95-
"for (int O2_s0_m = 0; O2_s0_m < M; O2_s0_m++) {",
96-
" for (int O2_s0_p = 0; O2_s0_p < P; O2_s0_p++) {",
97-
" O2[O2_s0_m][O2_s0_p] = 0.000000f",
98-
" for (int O2_s1_r_n = 0; O2_s1_r_n < N; O2_s1_r_n++) {",
99-
" O2[O2_s0_m][O2_s0_p] = (O2[O2_s0_m][O2_s0_p] + (O1[O2_s0_m][O2_s1_r_n]*W2[O2_s1_r_n][O2_s0_p]))",
100-
});
101+
R"C(
102+
for (int O1_s0_m = 0; O1_s0_m < M; O1_s0_m++) {
103+
for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {
104+
O1[O1_s0_m][O1_s0_n] = 0.000000f;
105+
}
106+
}
107+
for (int O1_s1_m = 0; O1_s1_m < M; O1_s1_m++) {
108+
for (int O1_s1_n = 0; O1_s1_n < N; O1_s1_n++) {
109+
for (int O1_s1_r_k = 0; O1_s1_r_k < K; O1_s1_r_k++) {
110+
O1[O1_s1_m][O1_s1_n] = (O1[O1_s1_m][O1_s1_n] + (I[O1_s1_m][O1_s1_r_k]*W1[O1_s1_r_k][O1_s1_n]));
111+
}
112+
}
113+
}
114+
for (int O2_s0_m = 0; O2_s0_m < M; O2_s0_m++) {
115+
for (int O2_s0_p = 0; O2_s0_p < P; O2_s0_p++) {
116+
O2[O2_s0_m][O2_s0_p] = 0.000000f;
117+
}
118+
}
119+
for (int O2_s1_m = 0; O2_s1_m < M; O2_s1_m++) {
120+
for (int O2_s1_p = 0; O2_s1_p < P; O2_s1_p++) {
121+
for (int O2_s1_r_n = 0; O2_s1_r_n < N; O2_s1_r_n++) {
122+
O2[O2_s1_m][O2_s1_p] = (O2[O2_s1_m][O2_s1_p] + (O1[O2_s1_m][O2_s1_r_n]*W2[O2_s1_r_n][O2_s1_p]));
123+
}
124+
}
125+
}
126+
)C");
101127
}
102128

103129
TEST_F(GenericHalideCoreTest, Convolution) {
@@ -108,15 +134,32 @@ def fun(float(N, C, H, W) I1, float(C, F, KH, KW) W1) -> (O1) {
108134
)TC";
109135
CheckC(
110136
tc,
111-
{"for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {",
112-
" for (int O1_s0_f = 0; O1_s0_f < F; O1_s0_f++) {",
113-
" for (int O1_s0_h = 0; O1_s0_h < ((H - KH) + 1); O1_s0_h++) {",
114-
" for (int O1_s0_w = 0; O1_s0_w < ((W - KW) + 1); O1_s0_w++) {",
115-
" O1[O1_s0_n][O1_s0_f][O1_s0_h][O1_s0_w] = 0.000000f",
116-
" for (int O1_s1_r_c = 0; O1_s1_r_c < C; O1_s1_r_c++) {",
117-
" for (int O1_s1_r_kh = 0; O1_s1_r_kh < KH; O1_s1_r_kh++) {",
118-
" for (int O1_s1_r_kw = 0; O1_s1_r_kw < KW; O1_s1_r_kw++) {",
119-
" O1[O1_s0_n][O1_s0_f][O1_s0_h][O1_s0_w] = (O1[O1_s0_n][O1_s0_f][O1_s0_h][O1_s0_w] + (I1[O1_s0_n][O1_s1_r_c][(O1_s0_h + O1_s1_r_kh)][(O1_s0_w + O1_s1_r_kw)]*W1[O1_s1_r_c][O1_s0_f][O1_s1_r_kh][O1_s1_r_kw]))"});
137+
R"C(
138+
for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {
139+
for (int O1_s0_f = 0; O1_s0_f < F; O1_s0_f++) {
140+
for (int O1_s0_h = 0; O1_s0_h < ((H - KH) + 1); O1_s0_h++) {
141+
for (int O1_s0_w = 0; O1_s0_w < ((W - KW) + 1); O1_s0_w++) {
142+
O1[O1_s0_n][O1_s0_f][O1_s0_h][O1_s0_w] = 0.000000f;
143+
}
144+
}
145+
}
146+
}
147+
for (int O1_s1_n = 0; O1_s1_n < N; O1_s1_n++) {
148+
for (int O1_s1_f = 0; O1_s1_f < F; O1_s1_f++) {
149+
for (int O1_s1_h = 0; O1_s1_h < ((H - KH) + 1); O1_s1_h++) {
150+
for (int O1_s1_w = 0; O1_s1_w < ((W - KW) + 1); O1_s1_w++) {
151+
for (int O1_s1_r_c = 0; O1_s1_r_c < C; O1_s1_r_c++) {
152+
for (int O1_s1_r_kh = 0; O1_s1_r_kh < KH; O1_s1_r_kh++) {
153+
for (int O1_s1_r_kw = 0; O1_s1_r_kw < KW; O1_s1_r_kw++) {
154+
O1[O1_s1_n][O1_s1_f][O1_s1_h][O1_s1_w] = (O1[O1_s1_n][O1_s1_f][O1_s1_h][O1_s1_w] + (I1[O1_s1_n][O1_s1_r_c][(O1_s1_h + O1_s1_r_kh)][(O1_s1_w + O1_s1_r_kw)]*W1[O1_s1_r_c][O1_s1_f][O1_s1_r_kh][O1_s1_r_kw]));
155+
}
156+
}
157+
}
158+
}
159+
}
160+
}
161+
}
162+
)C");
120163
}
121164

122165
TEST_F(GenericHalideCoreTest, Copy) {
@@ -136,27 +179,55 @@ def fun(float(N, G, C, H, W) I1, float(G, C, F, KH, KW) W1) -> (O1) {
136179
)TC";
137180
CheckC(
138181
tc,
139-
{"for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {",
140-
" for (int O1_s0_g = 0; O1_s0_g < G; O1_s0_g++) {",
141-
" for (int O1_s0_f = 0; O1_s0_f < F; O1_s0_f++) {",
142-
" for (int O1_s0_h = 0; O1_s0_h < ((H - KH) + 1); O1_s0_h++) {",
143-
" for (int O1_s0_w = 0; O1_s0_w < ((W - KW) + 1); O1_s0_w++) {",
144-
" O1[O1_s0_n][O1_s0_g][O1_s0_f][O1_s0_h][O1_s0_w] = 0.000000f",
145-
" for (int O1_s1_r_c = 0; O1_s1_r_c < C; O1_s1_r_c++) {",
146-
" for (int O1_s1_r_kh = 0; O1_s1_r_kh < KH; O1_s1_r_kh++) {",
147-
" for (int O1_s1_r_kw = 0; O1_s1_r_kw < KW; O1_s1_r_kw++) {",
148-
" O1[O1_s0_n][O1_s0_g][O1_s0_f][O1_s0_h][O1_s0_w] = (O1[O1_s0_n][O1_s0_g][O1_s0_f][O1_s0_h][O1_s0_w] + (I1[O1_s0_n][O1_s0_g][O1_s1_r_c][(O1_s0_h + O1_s1_r_kh)][(O1_s0_w + O1_s1_r_kw)]*W1[O1_s0_g][O1_s1_r_c][O1_s0_f][O1_s1_r_kh][O1_s1_r_kw]))"});
182+
R"C(
183+
for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {
184+
for (int O1_s0_g = 0; O1_s0_g < G; O1_s0_g++) {
185+
for (int O1_s0_f = 0; O1_s0_f < F; O1_s0_f++) {
186+
for (int O1_s0_h = 0; O1_s0_h < ((H - KH) + 1); O1_s0_h++) {
187+
for (int O1_s0_w = 0; O1_s0_w < ((W - KW) + 1); O1_s0_w++) {
188+
O1[O1_s0_n][O1_s0_g][O1_s0_f][O1_s0_h][O1_s0_w] = 0.000000f;
189+
}
190+
}
191+
}
192+
}
193+
}
194+
for (int O1_s1_n = 0; O1_s1_n < N; O1_s1_n++) {
195+
for (int O1_s1_g = 0; O1_s1_g < G; O1_s1_g++) {
196+
for (int O1_s1_f = 0; O1_s1_f < F; O1_s1_f++) {
197+
for (int O1_s1_h = 0; O1_s1_h < ((H - KH) + 1); O1_s1_h++) {
198+
for (int O1_s1_w = 0; O1_s1_w < ((W - KW) + 1); O1_s1_w++) {
199+
for (int O1_s1_r_c = 0; O1_s1_r_c < C; O1_s1_r_c++) {
200+
for (int O1_s1_r_kh = 0; O1_s1_r_kh < KH; O1_s1_r_kh++) {
201+
for (int O1_s1_r_kw = 0; O1_s1_r_kw < KW; O1_s1_r_kw++) {
202+
O1[O1_s1_n][O1_s1_g][O1_s1_f][O1_s1_h][O1_s1_w] = (O1[O1_s1_n][O1_s1_g][O1_s1_f][O1_s1_h][O1_s1_w] + (I1[O1_s1_n][O1_s1_g][O1_s1_r_c][(O1_s1_h + O1_s1_r_kh)][(O1_s1_w + O1_s1_r_kw)]*W1[O1_s1_g][O1_s1_r_c][O1_s1_f][O1_s1_r_kh][O1_s1_r_kw]));
203+
}
204+
}
205+
}
206+
}
207+
}
208+
}
209+
}
210+
}
211+
)C");
149212
}
150213

151214
TEST_F(GenericHalideCoreTest, Matmul) {
152215
CheckC(
153216
makeMatmulTc(false, false),
154-
std::vector<std::string>{
155-
"for (int O_s0_i = 0; O_s0_i < N; O_s0_i++) {",
156-
" for (int O_s0_j = 0; O_s0_j < M; O_s0_j++) {",
157-
" O[O_s0_i][O_s0_j] = 0.000000f;",
158-
" for (int O_s1_k = 0; O_s1_k < K; O_s1_k++) {",
159-
" O[O_s0_i][O_s0_j] = (O[O_s0_i][O_s0_j] + (A[O_s0_i][O_s1_k]*B[O_s1_k][O_s0_j]));"});
217+
R"C(
218+
for (int O_s0_i = 0; O_s0_i < N; O_s0_i++) {
219+
for (int O_s0_j = 0; O_s0_j < M; O_s0_j++) {
220+
O[O_s0_i][O_s0_j] = 0.000000f;
221+
}
222+
}
223+
for (int O_s1_i = 0; O_s1_i < N; O_s1_i++) {
224+
for (int O_s1_j = 0; O_s1_j < M; O_s1_j++) {
225+
for (int O_s1_k = 0; O_s1_k < K; O_s1_k++) {
226+
O[O_s1_i][O_s1_j] = (O[O_s1_i][O_s1_j] + (A[O_s1_i][O_s1_k]*B[O_s1_k][O_s1_j]));
227+
}
228+
}
229+
}
230+
)C");
160231
}
161232

162233
using namespace isl::with_exceptions;

test/test_cuda_mapper.cc

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -416,6 +416,10 @@ def fun(float(N, N, N, N) A, float(N, N) B, float(N, N) C, float(N, N) D)
416416
for (int c0 = 0; c0 < N; c0 += 1) {
417417
for (int c1 = 0; c1 < N; c1 += 1) {
418418
O1[c0][c1] = 0.000000f;
419+
}
420+
}
421+
for (int c0 = 0; c0 < N; c0 += 1) {
422+
for (int c1 = 0; c1 < N; c1 += 1) {
419423
for (int c2 = 0; c2 < N; c2 += 1) {
420424
for (int c3 = 0; c3 < N; c3 += 1) {
421425
O1[c0][c1] = (O1[c0][c1] + (A[c0][c1][c2][c3]*B[c0][c1]));

0 commit comments

Comments
 (0)