@@ -129,7 +129,10 @@ def avgpool(float(B, C, H, W) input) -> (output) {
         T = tc.define(tc_str, tc.make_naive_options_factory())
         inp = torch.ones(1, 1, 4, 4, device='cuda')
         out = T.avgpool(inp)
-        # TODO: test results!!!
+
+        from torch.nn.modules.pooling import AvgPool2d
+        ref = AvgPool2d(2, stride=1).forward(inp)
+        tc.assert_almost_equal(ref, out, inp)

     #
     # This test implements group normalization as a single TC kernel.
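Note on the check added above: it compares the TC avgpool output against torch.nn.AvgPool2d(2, stride=1), which implies the tc_str kernel performs 2x2 average pooling with stride 1 (the kernel body itself is outside this hunk). A minimal standalone sketch of that reference in plain PyTorch, without the TC dependency:

    import torch
    from torch.nn.modules.pooling import AvgPool2d

    inp = torch.ones(1, 1, 4, 4)
    # A 2x2 window with stride 1 over a 4x4 input yields a 3x3 output,
    # and average-pooling a tensor of ones gives ones back.
    ref = AvgPool2d(2, stride=1).forward(inp)
    assert ref.shape == (1, 1, 3, 3)
    assert torch.allclose(ref, torch.ones(1, 1, 3, 3))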
@@ -138,13 +141,16 @@ def avgpool(float(B, C, H, W) input) -> (output) {
     def test_group_norm_fused(self):
         group_normalization = """
 def group_normalization(
-    float(N, G, D, H, W) I, float(G, D) gamma, float(G, D) beta) -> (Sum, SumSq, O)
+    float(N, G, D, H, W) I, float(G, D) gamma, float(G, D) beta)
+    -> (Sum, SumSq, O)
 {
     Sum(n, g) +=! I(n, g, r_d, r_h, r_w)
     SumSq(n, g) +=! I(n, g, r_d, r_h, r_w) * I(n, g, r_d, r_h, r_w)
-    O(n, g, d, h, w) = gamma(g, d)
+    O(n, g, d, h, w) = gamma(g, d)
         * ( I(n, g, d, h, w) - Sum(n, g) / (D * H * W))
-        * rsqrt( (SumSq(n, g) / (D * H * W) - Sum(n, g) * Sum(n, g)) + 1e-5 )
+        * rsqrt( (SumSq(n, g) - Sum(n, g) * Sum(n, g) / (D * H * W))
+                 / (D * H * W)
+                 + 1e-5)
         + beta(g, d)
 }
 """
@@ -157,10 +163,15 @@ def group_normalization(
                 tuner_config=tuner_config))
         I, gamma, beta = (
             torch.randn(N, G, D, H, W, device='cuda'),
-            torch.randn(G, D, device='cuda'),
-            torch.randn(G, D, device='cuda'))
+            torch.randn(G, D, device='cuda').fill_(1.0),
+            torch.randn(G, D, device='cuda').zero_())
         Sum, SumSq, O = T.group_normalization(I, gamma, beta)
-        # TODO: test results!!!
+
+        from torch.nn.modules.normalization import GroupNorm
+        GN = GroupNorm(G, G * D).cuda()
+        ref = GN.forward(I.view((N, G * D, H, W)))
+
+        tc.assert_almost_equal(ref, O.view((N, G * D, H, W)), I, operations=D * H * W)

     #
     # This test implements group normalization as 2 TC kernels
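The reference added above relies on torch.nn.GroupNorm matching the TC kernel once gamma is filled with 1.0 and beta is zeroed (so the affine part is the identity, matching GroupNorm's default weight=1, bias=0), and once the (N, G, D, H, W) tensor is viewed as (N, G*D, H, W). A small sketch of that mapping with illustrative sizes (the real test uses much larger ones):

    import torch
    from torch.nn import GroupNorm

    N, G, D, H, W = 2, 4, 3, 5, 5
    I = torch.randn(N, G, D, H, W)

    # GroupNorm(G, G * D) on the (N, G * D, H, W) view normalizes each group
    # of D channels, i.e. exactly each (n, g) slice of I.
    GN = GroupNorm(G, G * D)
    ref = GN(I.view(N, G * D, H, W))

    flat = I.view(N, G, -1)
    mean = flat.mean(-1, keepdim=True)
    var = flat.var(-1, unbiased=False, keepdim=True)
    manual = ((flat - mean) / torch.sqrt(var + GN.eps)).view(N, G * D, H, W)
    assert torch.allclose(ref, manual, atol=1e-4)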
@@ -191,8 +202,8 @@ def group_normalization(
         N, G, D, H, W = 32, 32, 4, 56, 56
         I, gamma, beta = (
             torch.randn(N, G, D, H, W, device='cuda'),
-            torch.randn(G, D, device='cuda'),
-            torch.randn(G, D, device='cuda'))
+            torch.randn(G, D, device='cuda').fill_(1.0),
+            torch.randn(G, D, device='cuda').zero_())

         T = tc.define(
             group_normalization,
@@ -208,7 +219,12 @@ def group_normalization(
         mean, var = T.moments(I.view((N * G, -1)))
         out = T.group_normalization(
             I, gamma, beta, mean.view((N, G)), var.view((N, G)))
-        # TODO: test results!!!
+
+        from torch.nn.modules.normalization import GroupNorm
+        GN = GroupNorm(G, G * D).cuda()
+        ref = GN.forward(I.view((N, G * D, H, W)))
+
+        tc.assert_almost_equal(ref, out.view((N, G * D, H, W)), I, operations=D * H * W)

     #
     # TC example without fallback but with tuning starting from MappingOptions('naive').
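The two-kernel variant above first calls T.moments on the (N*G, D*H*W) view and then feeds the reshaped statistics into group_normalization. Assuming moments returns the per-row mean and biased variance of its 2-D input, the plain-PyTorch equivalent looks like this (sizes are illustrative):

    import torch

    N, G, D, H, W = 2, 4, 3, 5, 5
    I = torch.randn(N, G, D, H, W)

    flat = I.view(N * G, -1)                  # one row per (n, g) group
    mean = flat.mean(dim=1)                   # shape (N * G,)
    var = flat.var(dim=1, unbiased=False)     # population variance, shape (N * G,)

    # Viewed as (N, G), these are the per-group statistics that the second
    # kernel consumes alongside gamma and beta.
    mean_ng, var_ng = mean.view(N, G), var.view(N, G)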
@@ -239,8 +255,8 @@ def group_normalization(
         N, G, D, H, W = 32, 32, 4, 56, 56
         I, gamma, beta = (
             torch.randn(N, G, D, H, W, device='cuda'),
-            torch.randn(G, D, device='cuda'),
-            torch.randn(G, D, device='cuda'))
+            torch.randn(G, D, device='cuda').fill_(1.0),
+            torch.randn(G, D, device='cuda').zero_())

         T = tc.define(
             group_normalization,
@@ -266,45 +282,63 @@ def group_normalization(
         out = T.group_normalization(
             I, gamma, beta, mean.view((N, G)), var.view((N, G)))

+        from torch.nn.modules.normalization import GroupNorm
+        GN = GroupNorm(G, G * D).cuda()
+        ref = GN.forward(I.view((N, G * D, H, W)))
+
+        tc.assert_almost_equal(ref, out.view((N, G * D, H, W)), I, operations=D * H * W)
+

     #
     # This tests single kernel forward/backward with tc.make_autograd.
     #
     def test_conv_with_backward_fused(self):
         conv = """
-def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {
+def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(M) Bias)
+    -> (O)
+{
     O(n, m, h, w) +=!
         I(n, r_c, h + r_kh, w + r_kw) * W1(m, r_c, r_kh, r_kw)
+    O(n, m, h, w) = O(n, m, h, w) + Bias(m)
 }
 def convolution_grad(
-    float(N,C,H,W) I, float(M,C,KH,KW) W1, float(N,M,H,W) d_O)
-    -> (d_I, d_W1)
+    float(N,C,H,W) I, float(M,C,KH,KW) W1, float(M) Bias, float(N,M,H,W) d_O)
+    -> (d_I, d_W1, d_Bias)
 {
     d_I(n, c, h, w) +=!
         d_O(n, r_m, h - r_kh, w - r_kw) * W1(r_m, c, r_kh, r_kw)
     d_W1(m, c, kh, kw) +=!
         d_O(r_n, m, r_h - kh, r_w - kw) * I(r_n, c, r_h, r_w)
+    # TODO: Bias incorrect + check
+    d_Bias(m) = Bias(m)
 }
 """

         N, C, H, W, O, kH, kW = 32, 4, 56, 56, 16, 1, 1
-        I, W = (
-            torch.randn(N, C, H, W, device='cuda', requires_grad=True),
-            torch.randn(O, C, kH, kW, device='cuda', requires_grad=True))
+        I = torch.randn(N, C, H, W, device='cuda', requires_grad=True)
         T = tc.define(
             conv,
             tc.make_autotuned_options_factory(
                 starting_options='naive',
                 tuner_config=tuner_config))
         convolution = tc.make_autograd(T.convolution, T.convolution_grad)

+        # Reference
+        from torch.nn.modules.conv import Conv2d
+        Conv = Conv2d(C, O, 1, stride=1).cuda()
+        ref = Conv.forward(I)
+
+        W = Conv.weight.clone()
+        Bias = Conv.bias.clone()
+
         # First occurrence triggers tuning (make_autotuned_options_factory)
-        out = convolution(I, W)
+        out = convolution(I, W, Bias)
         out.sum().backward()

-        out = convolution(I, W)
+        out = convolution(I, W, Bias)
         out.sum().backward()
-        # TODO: test results!!!
+
+        tc.assert_almost_equal(ref, out, I, operations=C * kH * kW)

     #
     # This tests 1-kernel forward/ 2-kernel backward with tc.make_autograd.
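The fused forward/backward test above builds its reference by instantiating a Conv2d(C, O, 1, stride=1) module and cloning its weight and bias as the TC inputs, so ref and out are computed from identical parameters. A sketch of why a 1x1 Conv2d is the right reference for the TC convolution definition (sizes are illustrative; the test itself uses 32, 4, 56, 56, 16, 1, 1):

    import torch
    from torch.nn import Conv2d

    N, C, H, W, O = 8, 4, 16, 16, 16
    I = torch.randn(N, C, H, W)

    Conv = Conv2d(C, O, 1, stride=1)
    ref = Conv(I)

    # The same 1x1 convolution written out explicitly, mirroring the TC definition:
    #   O(n, m, h, w) = sum_c I(n, c, h, w) * W1(m, c, 0, 0) + Bias(m)
    W1 = Conv.weight                          # shape (O, C, 1, 1)
    Bias = Conv.bias                          # shape (O,)
    manual = torch.einsum('nchw,mc->nmhw', I, W1[:, :, 0, 0]) + Bias.view(1, O, 1, 1)
    assert torch.allclose(ref, manual, atol=1e-4)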
@@ -314,9 +348,12 @@ def convolution_grad(
     #
     def test_conv_with_backward_2kernels(self):
         conv = """
-def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {
+def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(M) Bias)
+    -> (O)
+{
     O(n, m, h, w) +=!
         I(n, r_c, h + r_kh, w + r_kw) * W1(m, r_c, r_kh, r_kw)
+    O(n, m, h, w) = O(n, m, h, w) + Bias(m)
 }
 def convolution_igrad(float(M,C,KH,KW) W1, float(N,M,H,W) d_O)
     -> (d_I)
@@ -329,6 +366,11 @@ def convolution_wgrad(float(N,C,H,W) I, float(N,M,H,W) d_O) -> (d_W1)
     d_W1(m, c, kh, kw) +=!
         d_O(r_n, m, r_h - kh, r_w - kw) * I(r_n, c, r_h, r_w)
 }
+def convolution_biasgrad(float(M) Bias) -> (d_Bias)
+{
+    # TODO: Bias incorrect + check
+    d_Bias(m) = Bias(m)
+}
 """

         N, C, H, W, O, kH, kW = 32, 4, 56, 56, 16, 1, 1
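The convolution_biasgrad kernel added above is explicitly flagged as a placeholder: it returns Bias itself rather than a gradient. The actual bias gradient of a convolution-plus-bias is the reduction of the upstream gradient over the batch and spatial axes, which in TC terms would presumably be a reduction such as d_Bias(m) +=! d_O(r_n, m, r_h, r_w) with d_O as an input. A quick autograd check of that fact in plain PyTorch (sizes are illustrative):

    import torch
    import torch.nn.functional as F

    N, C, H, W, O = 8, 4, 16, 16, 16
    I = torch.randn(N, C, H, W)
    W1 = torch.randn(O, C, 1, 1)
    Bias = torch.randn(O, requires_grad=True)

    out = F.conv2d(I, W1, bias=Bias)
    d_O = torch.randn_like(out)
    out.backward(d_O)

    # The bias enters the output additively at every (n, h, w) position, so its
    # gradient is the reduction of the upstream gradient over those axes.
    assert torch.allclose(Bias.grad, d_O.sum(dim=(0, 2, 3)), atol=1e-3)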
@@ -337,26 +379,34 @@ def convolution_wgrad(float(N,C,H,W) I, float(N,M,H,W) d_O) -> (d_W1)
             tc.make_autotuned_options_factory(
                 starting_options='naive',
                 tuner_config=tuner_config))
-        I, W = (
-            torch.randn(N, C, H, W, device='cuda', requires_grad=True),
-            torch.randn(O, C, kH, kW, device='cuda', requires_grad=True))
+        I = torch.randn(N, C, H, W, device='cuda', requires_grad=True)
+
+        # Reference
+        from torch.nn.modules.conv import Conv2d
+        Conv = Conv2d(C, O, 1, stride=1).cuda()
+        ref = Conv.forward(I)

-        def convolution_backward(I, W, d_O):
+        W = Conv.weight.clone()
+        Bias = Conv.bias.clone()
+
+        def convolution_backward(I, W, Bias, d_O):
             d_I = T.convolution_igrad(W, d_O)
             d_O = T.convolution_wgrad(I, d_O)
-            return (d_I, d_O)
+            d_Bias = T.convolution_biasgrad(Bias)
+            return (d_I, d_O, d_Bias)

         convolution_function = tc.make_autograd(
             T.convolution, convolution_backward)

         # First occurrence triggers tuning
-        out = convolution_function(I, W)
+        out = convolution_function(I, W, Bias)
         out.sum().backward()

         # Subsequent occurrences do not
-        out = convolution_function(I, W)
+        out = convolution_function(I, W, Bias)
         out.sum().backward()
-        # TODO: test results!!!
+
+        tc.assert_almost_equal(ref, out, I, operations=C * kH * kW)

     #
     # This tests the direct use of pybinds which are closer to C++
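In the 2-kernel test above, the backward passed to tc.make_autograd is a plain Python function that runs one compiled kernel per gradient and returns one gradient per forward input. A rough plain-PyTorch analogue of that shape, using a toy elementwise op rather than the TC API:

    import torch

    class ScaleAndShift(torch.autograd.Function):
        """Toy analogue of the 1-forward / N-backward-kernels pattern: one
        fused forward, with the backward split into independent pieces, one
        per input, like convolution_igrad / _wgrad / _biasgrad above."""

        @staticmethod
        def forward(ctx, x, w, b):
            ctx.save_for_backward(x, w)
            return x * w + b

        @staticmethod
        def backward(ctx, d_out):
            x, w = ctx.saved_tensors
            d_x = d_out * w                  # "igrad": gradient w.r.t. the input
            d_w = (d_out * x).sum()          # "wgrad": gradient w.r.t. the weight
            d_b = d_out.sum()                # "biasgrad": gradient w.r.t. the bias
            return d_x, d_w, d_b

    x = torch.randn(5, requires_grad=True)
    w = torch.randn((), requires_grad=True)
    b = torch.randn((), requires_grad=True)
    ScaleAndShift.apply(x, w, b).sum().backward()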
@@ -424,7 +474,13 @@ def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1)
         executor = tclib.compile(
             tensordot_str, entry_point, (I0, I1), best_options)
         O = executor.run((I0, I1), ())
-        # TODO: test results!!!
+
+        # No simple torch baseline, compare against naive
+        executor = tclib.compile(
+            tensordot_str, entry_point, (I0, I1), tc.MappingOptions('naive'))
+        ref = executor.run((I0, I1), ())
+
+        tc.assert_almost_equal(ref, O, I0, I1, operations=C2)

 if __name__ == '__main__':
     unittest.main()
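On the tensordot check at the end: the comment notes there is no simple torch baseline, so the tuned kernel is validated against the naive mapping of the same TC. If the kernel contracts the shared C2 dimension, as its signature and the operations=C2 tolerance suggest (the body is not shown in this hunk), a torch.einsum expression could serve as an extra sanity check:

    import torch

    N, C1, C2, C3, H, W = 2, 3, 4, 5, 6, 6
    I0 = torch.randn(N, C1, C2, H, W)
    I1 = torch.randn(N, C2, C3, H, W)

    # Assumed semantics:
    #   O(n, c1, c3, h, w) +=! I0(n, c1, r_c2, h, w) * I1(n, r_c2, c3, h, w)
    O = torch.einsum('nachw,ncbhw->nabhw', I0, I1)
    assert O.shape == (N, C1, C3, H, W)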