@@ -75,6 +75,18 @@ struct GenericHalideCoreTest : public ::testing::Test {
75
75
curPos = newPos;
76
76
}
77
77
}
78
+ void CheckC (const std::string& tc, const std::string& expected) {
79
+ std::istringstream stream (expected);
80
+ std::string line;
81
+ std::vector<std::string> split;
82
+ while (std::getline (stream, line)) {
83
+ // Skip lines containing (only) closing brace.
84
+ if (line.find (' }' ) == std::string::npos) {
85
+ split.emplace_back (line);
86
+ }
87
+ }
88
+ CheckC (tc, split);
89
+ }
78
90
};
79
91
80
92
TEST_F (GenericHalideCoreTest, TwoMatmul) {
@@ -86,18 +98,32 @@ def fun(float(M, K) I, float(K, N) W1, float(N, P) W2) -> (O1, O2) {
86
98
)TC" ;
87
99
CheckC (
88
100
tc,
89
- {
90
- " for (int O1_s0_m = 0; O1_s0_m < M; O1_s0_m++) {" ,
91
- " for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {" ,
92
- " O1[O1_s0_m][O1_s0_n] = 0.000000f" ,
93
- " for (int O1_s1_r_k = 0; O1_s1_r_k < K; O1_s1_r_k++) {" ,
94
- " O1[O1_s0_m][O1_s0_n] = (O1[O1_s0_m][O1_s0_n] + (I[O1_s0_m][O1_s1_r_k]*W1[O1_s1_r_k][O1_s0_n]))" ,
95
- " for (int O2_s0_m = 0; O2_s0_m < M; O2_s0_m++) {" ,
96
- " for (int O2_s0_p = 0; O2_s0_p < P; O2_s0_p++) {" ,
97
- " O2[O2_s0_m][O2_s0_p] = 0.000000f" ,
98
- " for (int O2_s1_r_n = 0; O2_s1_r_n < N; O2_s1_r_n++) {" ,
99
- " O2[O2_s0_m][O2_s0_p] = (O2[O2_s0_m][O2_s0_p] + (O1[O2_s0_m][O2_s1_r_n]*W2[O2_s1_r_n][O2_s0_p]))" ,
100
- });
101
+ R"C(
102
+ for (int O1_s0_m = 0; O1_s0_m < M; O1_s0_m++) {
103
+ for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {
104
+ O1[O1_s0_m][O1_s0_n] = 0.000000f;
105
+ }
106
+ }
107
+ for (int O1_s1_m = 0; O1_s1_m < M; O1_s1_m++) {
108
+ for (int O1_s1_n = 0; O1_s1_n < N; O1_s1_n++) {
109
+ for (int O1_s1_r_k = 0; O1_s1_r_k < K; O1_s1_r_k++) {
110
+ O1[O1_s1_m][O1_s1_n] = (O1[O1_s1_m][O1_s1_n] + (I[O1_s1_m][O1_s1_r_k]*W1[O1_s1_r_k][O1_s1_n]));
111
+ }
112
+ }
113
+ }
114
+ for (int O2_s0_m = 0; O2_s0_m < M; O2_s0_m++) {
115
+ for (int O2_s0_p = 0; O2_s0_p < P; O2_s0_p++) {
116
+ O2[O2_s0_m][O2_s0_p] = 0.000000f;
117
+ }
118
+ }
119
+ for (int O2_s1_m = 0; O2_s1_m < M; O2_s1_m++) {
120
+ for (int O2_s1_p = 0; O2_s1_p < P; O2_s1_p++) {
121
+ for (int O2_s1_r_n = 0; O2_s1_r_n < N; O2_s1_r_n++) {
122
+ O2[O2_s1_m][O2_s1_p] = (O2[O2_s1_m][O2_s1_p] + (O1[O2_s1_m][O2_s1_r_n]*W2[O2_s1_r_n][O2_s1_p]));
123
+ }
124
+ }
125
+ }
126
+ )C" );
101
127
}
102
128
103
129
TEST_F (GenericHalideCoreTest, Convolution) {
@@ -108,15 +134,32 @@ def fun(float(N, C, H, W) I1, float(C, F, KH, KW) W1) -> (O1) {
108
134
)TC" ;
109
135
CheckC (
110
136
tc,
111
- {" for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {" ,
112
- " for (int O1_s0_f = 0; O1_s0_f < F; O1_s0_f++) {" ,
113
- " for (int O1_s0_h = 0; O1_s0_h < ((H - KH) + 1); O1_s0_h++) {" ,
114
- " for (int O1_s0_w = 0; O1_s0_w < ((W - KW) + 1); O1_s0_w++) {" ,
115
- " O1[O1_s0_n][O1_s0_f][O1_s0_h][O1_s0_w] = 0.000000f" ,
116
- " for (int O1_s1_r_c = 0; O1_s1_r_c < C; O1_s1_r_c++) {" ,
117
- " for (int O1_s1_r_kh = 0; O1_s1_r_kh < KH; O1_s1_r_kh++) {" ,
118
- " for (int O1_s1_r_kw = 0; O1_s1_r_kw < KW; O1_s1_r_kw++) {" ,
119
- " O1[O1_s0_n][O1_s0_f][O1_s0_h][O1_s0_w] = (O1[O1_s0_n][O1_s0_f][O1_s0_h][O1_s0_w] + (I1[O1_s0_n][O1_s1_r_c][(O1_s0_h + O1_s1_r_kh)][(O1_s0_w + O1_s1_r_kw)]*W1[O1_s1_r_c][O1_s0_f][O1_s1_r_kh][O1_s1_r_kw]))" });
137
+ R"C(
138
+ for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {
139
+ for (int O1_s0_f = 0; O1_s0_f < F; O1_s0_f++) {
140
+ for (int O1_s0_h = 0; O1_s0_h < ((H - KH) + 1); O1_s0_h++) {
141
+ for (int O1_s0_w = 0; O1_s0_w < ((W - KW) + 1); O1_s0_w++) {
142
+ O1[O1_s0_n][O1_s0_f][O1_s0_h][O1_s0_w] = 0.000000f;
143
+ }
144
+ }
145
+ }
146
+ }
147
+ for (int O1_s1_n = 0; O1_s1_n < N; O1_s1_n++) {
148
+ for (int O1_s1_f = 0; O1_s1_f < F; O1_s1_f++) {
149
+ for (int O1_s1_h = 0; O1_s1_h < ((H - KH) + 1); O1_s1_h++) {
150
+ for (int O1_s1_w = 0; O1_s1_w < ((W - KW) + 1); O1_s1_w++) {
151
+ for (int O1_s1_r_c = 0; O1_s1_r_c < C; O1_s1_r_c++) {
152
+ for (int O1_s1_r_kh = 0; O1_s1_r_kh < KH; O1_s1_r_kh++) {
153
+ for (int O1_s1_r_kw = 0; O1_s1_r_kw < KW; O1_s1_r_kw++) {
154
+ O1[O1_s1_n][O1_s1_f][O1_s1_h][O1_s1_w] = (O1[O1_s1_n][O1_s1_f][O1_s1_h][O1_s1_w] + (I1[O1_s1_n][O1_s1_r_c][(O1_s1_h + O1_s1_r_kh)][(O1_s1_w + O1_s1_r_kw)]*W1[O1_s1_r_c][O1_s1_f][O1_s1_r_kh][O1_s1_r_kw]));
155
+ }
156
+ }
157
+ }
158
+ }
159
+ }
160
+ }
161
+ }
162
+ )C" );
120
163
}
121
164
122
165
TEST_F (GenericHalideCoreTest, Copy) {
@@ -136,27 +179,55 @@ def fun(float(N, G, C, H, W) I1, float(G, C, F, KH, KW) W1) -> (O1) {
136
179
)TC" ;
137
180
CheckC (
138
181
tc,
139
- {" for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {" ,
140
- " for (int O1_s0_g = 0; O1_s0_g < G; O1_s0_g++) {" ,
141
- " for (int O1_s0_f = 0; O1_s0_f < F; O1_s0_f++) {" ,
142
- " for (int O1_s0_h = 0; O1_s0_h < ((H - KH) + 1); O1_s0_h++) {" ,
143
- " for (int O1_s0_w = 0; O1_s0_w < ((W - KW) + 1); O1_s0_w++) {" ,
144
- " O1[O1_s0_n][O1_s0_g][O1_s0_f][O1_s0_h][O1_s0_w] = 0.000000f" ,
145
- " for (int O1_s1_r_c = 0; O1_s1_r_c < C; O1_s1_r_c++) {" ,
146
- " for (int O1_s1_r_kh = 0; O1_s1_r_kh < KH; O1_s1_r_kh++) {" ,
147
- " for (int O1_s1_r_kw = 0; O1_s1_r_kw < KW; O1_s1_r_kw++) {" ,
148
- " O1[O1_s0_n][O1_s0_g][O1_s0_f][O1_s0_h][O1_s0_w] = (O1[O1_s0_n][O1_s0_g][O1_s0_f][O1_s0_h][O1_s0_w] + (I1[O1_s0_n][O1_s0_g][O1_s1_r_c][(O1_s0_h + O1_s1_r_kh)][(O1_s0_w + O1_s1_r_kw)]*W1[O1_s0_g][O1_s1_r_c][O1_s0_f][O1_s1_r_kh][O1_s1_r_kw]))" });
182
+ R"C(
183
+ for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {
184
+ for (int O1_s0_g = 0; O1_s0_g < G; O1_s0_g++) {
185
+ for (int O1_s0_f = 0; O1_s0_f < F; O1_s0_f++) {
186
+ for (int O1_s0_h = 0; O1_s0_h < ((H - KH) + 1); O1_s0_h++) {
187
+ for (int O1_s0_w = 0; O1_s0_w < ((W - KW) + 1); O1_s0_w++) {
188
+ O1[O1_s0_n][O1_s0_g][O1_s0_f][O1_s0_h][O1_s0_w] = 0.000000f;
189
+ }
190
+ }
191
+ }
192
+ }
193
+ }
194
+ for (int O1_s1_n = 0; O1_s1_n < N; O1_s1_n++) {
195
+ for (int O1_s1_g = 0; O1_s1_g < G; O1_s1_g++) {
196
+ for (int O1_s1_f = 0; O1_s1_f < F; O1_s1_f++) {
197
+ for (int O1_s1_h = 0; O1_s1_h < ((H - KH) + 1); O1_s1_h++) {
198
+ for (int O1_s1_w = 0; O1_s1_w < ((W - KW) + 1); O1_s1_w++) {
199
+ for (int O1_s1_r_c = 0; O1_s1_r_c < C; O1_s1_r_c++) {
200
+ for (int O1_s1_r_kh = 0; O1_s1_r_kh < KH; O1_s1_r_kh++) {
201
+ for (int O1_s1_r_kw = 0; O1_s1_r_kw < KW; O1_s1_r_kw++) {
202
+ O1[O1_s1_n][O1_s1_g][O1_s1_f][O1_s1_h][O1_s1_w] = (O1[O1_s1_n][O1_s1_g][O1_s1_f][O1_s1_h][O1_s1_w] + (I1[O1_s1_n][O1_s1_g][O1_s1_r_c][(O1_s1_h + O1_s1_r_kh)][(O1_s1_w + O1_s1_r_kw)]*W1[O1_s1_g][O1_s1_r_c][O1_s1_f][O1_s1_r_kh][O1_s1_r_kw]));
203
+ }
204
+ }
205
+ }
206
+ }
207
+ }
208
+ }
209
+ }
210
+ }
211
+ )C" );
149
212
}
150
213
151
214
TEST_F (GenericHalideCoreTest, Matmul) {
152
215
CheckC (
153
216
makeMatmulTc (false , false ),
154
- std::vector<std::string>{
155
- " for (int O_s0_i = 0; O_s0_i < N; O_s0_i++) {" ,
156
- " for (int O_s0_j = 0; O_s0_j < M; O_s0_j++) {" ,
157
- " O[O_s0_i][O_s0_j] = 0.000000f;" ,
158
- " for (int O_s1_k = 0; O_s1_k < K; O_s1_k++) {" ,
159
- " O[O_s0_i][O_s0_j] = (O[O_s0_i][O_s0_j] + (A[O_s0_i][O_s1_k]*B[O_s1_k][O_s0_j]));" });
217
+ R"C(
218
+ for (int O_s0_i = 0; O_s0_i < N; O_s0_i++) {
219
+ for (int O_s0_j = 0; O_s0_j < M; O_s0_j++) {
220
+ O[O_s0_i][O_s0_j] = 0.000000f;
221
+ }
222
+ }
223
+ for (int O_s1_i = 0; O_s1_i < N; O_s1_i++) {
224
+ for (int O_s1_j = 0; O_s1_j < M; O_s1_j++) {
225
+ for (int O_s1_k = 0; O_s1_k < K; O_s1_k++) {
226
+ O[O_s1_i][O_s1_j] = (O[O_s1_i][O_s1_j] + (A[O_s1_i][O_s1_k]*B[O_s1_k][O_s1_j]));
227
+ }
228
+ }
229
+ }
230
+ )C" );
160
231
}
161
232
162
233
using namespace isl ::with_exceptions;
0 commit comments