
Commit e0ea796

Add reference checks for python tests
1 parent c6fc827 commit e0ea796

File tree: 1 file changed (+87, -31 lines)


python/tests/test_tc.py

Lines changed: 87 additions & 31 deletions
@@ -129,7 +129,10 @@ def avgpool(float(B, C, H, W) input) -> (output) {
         T = tc.define(tc_str, tc.make_naive_options_factory())
         inp = torch.ones(1, 1, 4, 4, device='cuda')
         out = T.avgpool(inp)
-        # TODO: test results!!!
+
+        from torch.nn.modules.pooling import AvgPool2d
+        ref = AvgPool2d(2, stride=1).forward(inp)
+        tc.assert_almost_equal(ref, out, inp)
 
     #
     # This test implements group normalization as a single TC kernel.
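
The pattern used throughout this commit replaces the old "# TODO: test results!!!" markers with a PyTorch reference computation followed by tc.assert_almost_equal. A standalone sketch of the avgpool check (assumption: CPU tensors and torch.allclose standing in for the project's tc.assert_almost_equal helper, not part of the commit):

    import torch
    from torch.nn.modules.pooling import AvgPool2d

    inp = torch.ones(1, 1, 4, 4)                # same data as the test, minus device='cuda'
    ref = AvgPool2d(2, stride=1)(inp)           # unpadded 2x2 mean with stride 1
    assert ref.shape == (1, 1, 3, 3)            # 4x4 input, 2x2 window, stride 1
    assert torch.allclose(ref, torch.ones(1, 1, 3, 3))   # the mean of ones is ones
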
@@ -138,13 +141,16 @@ def avgpool(float(B, C, H, W) input) -> (output) {
     def test_group_norm_fused(self):
         group_normalization = """
         def group_normalization(
-            float(N, G, D, H, W) I, float(G, D) gamma, float(G, D) beta) -> (Sum, SumSq, O)
+            float(N, G, D, H, W) I, float(G, D) gamma, float(G, D) beta)
+            -> (Sum, SumSq, O)
         {
             Sum(n, g) +=! I(n, g, r_d, r_h, r_w)
             SumSq(n, g) +=! I(n, g, r_d, r_h, r_w) * I(n, g, r_d, r_h, r_w)
-            O(n, g, d, h, w) = gamma(g, d)
+            O(n, g, d, h, w) = gamma(g, d)
                 * ( I(n, g, d, h, w) - Sum(n, g) / (D * H * W))
-                * rsqrt( (SumSq(n, g) / (D * H * W) - Sum(n, g) * Sum(n, g)) + 1e-5 )
+                * rsqrt( (SumSq(n, g) - Sum(n, g) * Sum(n, g) / (D * H * W))
+                         / (D * H * W)
+                         + 1e-5)
                 + beta(g, d)
         }
         """
@@ -157,10 +163,15 @@ def group_normalization(
                 tuner_config=tuner_config))
         I, gamma, beta = (
             torch.randn(N, G, D, H, W, device='cuda'),
-            torch.randn(G, D, device='cuda'),
-            torch.randn(G, D, device='cuda'))
+            torch.randn(G, D, device='cuda').fill_(1.0),
+            torch.randn(G, D, device='cuda').zero_())
         Sum, SumSq, O = T.group_normalization(I, gamma, beta)
-        # TODO: test results!!!
+
+        from torch.nn.modules.normalization import GroupNorm
+        GN = GroupNorm(G, G*D).cuda()
+        ref = GN.forward(I.view((N, G*D, H, W)))
+
+        tc.assert_almost_equal(ref, O.view((N, G*D, H, W)), I, operations=D*H*W)
 
     #
     # This test implements group normalization as 2 TC kernels
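
gamma is filled with ones and beta with zeros because torch's GroupNorm initializes its affine weight to ones and bias to zeros, so both sides compute a pure per-group normalization. A small CPU sketch of that equivalence (assumption: not part of the commit):

    import torch
    from torch.nn.modules.normalization import GroupNorm

    N, G, D, H, W = 2, 4, 3, 5, 5
    I = torch.randn(N, G, D, H, W)
    GN = GroupNorm(G, G * D)                                  # num_groups=G, num_channels=G*D
    assert torch.all(GN.weight == 1) and torch.all(GN.bias == 0)

    ref = GN(I.view(N, G * D, H, W))
    x = I.view(N, G, -1)                                      # one row per (sample, group)
    manual = (x - x.mean(dim=-1, keepdim=True)) \
        / torch.sqrt(x.var(dim=-1, unbiased=False, keepdim=True) + GN.eps)
    assert torch.allclose(ref, manual.view(N, G * D, H, W), atol=1e-5)
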
@@ -191,8 +202,8 @@ def group_normalization(
         N, G, D, H, W = 32, 32, 4, 56, 56
         I, gamma, beta = (
             torch.randn(N, G, D, H, W, device='cuda'),
-            torch.randn(G, D, device='cuda'),
-            torch.randn(G, D, device='cuda'))
+            torch.randn(G, D, device='cuda').fill_(1.0),
+            torch.randn(G, D, device='cuda').zero_())
 
         T = tc.define(
             group_normalization,
@@ -208,7 +219,12 @@ def group_normalization(
         mean, var = T.moments(I.view((N * G, -1)))
         out = T.group_normalization(
             I, gamma, beta, mean.view((N, G)), var.view((N, G)))
-        # TODO: test results!!!
+
+        from torch.nn.modules.normalization import GroupNorm
+        GN = GroupNorm(G, G*D).cuda()
+        ref = GN.forward(I.view((N, G*D, H, W)))
+
+        tc.assert_almost_equal(ref, out.view((N, G*D, H, W)), I, operations=D*H*W)
 
     #
     # TC example without fallback but with tuning starting from MappingOptions('naive').
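
The 2-kernel variant computes moments over I.view((N * G, -1)) and reshapes them to (N, G) before normalizing. A small CPU sketch of that bookkeeping (assumption: plain PyTorch, not from the commit):

    import torch

    N, G, D, H, W = 2, 4, 3, 5, 5
    I = torch.randn(N, G, D, H, W)

    flat = I.view(N * G, -1)
    mean = flat.mean(dim=1).view(N, G)                   # role of T.moments' first output
    var = flat.var(dim=1, unbiased=False).view(N, G)     # role of its second output
    assert torch.allclose(mean, I.view(N, G, -1).mean(dim=-1))   # per-(sample, group) means agree
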
@@ -239,8 +255,8 @@ def group_normalization(
         N, G, D, H, W = 32, 32, 4, 56, 56
         I, gamma, beta = (
             torch.randn(N, G, D, H, W, device='cuda'),
-            torch.randn(G, D, device='cuda'),
-            torch.randn(G, D, device='cuda'))
+            torch.randn(G, D, device='cuda').fill_(1.0),
+            torch.randn(G, D, device='cuda').zero_())
 
         T = tc.define(
             group_normalization,
@@ -266,45 +282,63 @@ def group_normalization(
         out = T.group_normalization(
             I, gamma, beta, mean.view((N, G)), var.view((N, G)))
 
+        from torch.nn.modules.normalization import GroupNorm
+        GN = GroupNorm(G, G*D).cuda()
+        ref = GN.forward(I.view((N, G*D, H, W)))
+
+        tc.assert_almost_equal(ref, out.view((N, G*D, H, W)), I, operations=D*H*W)
+
 
     #
     # This tests single kernel forward/backward with tc.make_autograd.
     #
     def test_conv_with_backward_fused(self):
         conv = """
-        def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {
+        def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(M) Bias)
+        -> (O)
+        {
             O(n, m, h, w) +=!
                 I(n, r_c, h + r_kh, w + r_kw) * W1(m, r_c, r_kh, r_kw)
+            O(n, m, h, w) = O(n, m, h, w) + Bias(m)
         }
         def convolution_grad(
-            float(N,C,H,W) I, float(M,C,KH,KW) W1, float(N,M,H,W) d_O)
-            -> (d_I, d_W1)
+            float(N,C,H,W) I, float(M,C,KH,KW) W1, float(M) Bias, float(N,M,H,W) d_O)
+            -> (d_I, d_W1, d_Bias)
         {
             d_I(n, c, h, w) +=!
                 d_O( n, r_m, h - r_kh, w - r_kw) * W1(r_m, c, r_kh, r_kw)
             d_W1(m, c, kh, kw) +=!
                 d_O(r_n, m, r_h - kh, r_w - kw) * I(r_n, c, r_h, r_w)
+            # TODO: Bias incorrect + check
+            d_Bias(m) = Bias(m)
         }
         """
 
         N, C, H, W, O, kH, kW = 32, 4, 56, 56, 16, 1, 1
-        I, W = (
-            torch.randn(N, C, H, W, device='cuda', requires_grad=True),
-            torch.randn(O, C, kH, kW, device='cuda', requires_grad=True))
+        I = torch.randn(N, C, H, W, device='cuda', requires_grad=True)
         T = tc.define(
             conv,
             tc.make_autotuned_options_factory(
                 starting_options='naive',
                 tuner_config=tuner_config))
         convolution = tc.make_autograd(T.convolution, T.convolution_grad)
 
+        # Reference
+        from torch.nn.modules.conv import Conv2d
+        Conv = Conv2d(C, O, 1, stride=1).cuda()
+        ref = Conv.forward(I)
+
+        W = Conv.weight.clone()
+        Bias = Conv.bias.clone()
+
         # First occurrence triggers tuning (make_autotuned_options_factory)
-        out = convolution(I, W)
+        out = convolution(I, W, Bias)
         out.sum().backward()
 
-        out = convolution(I, W)
+        out = convolution(I, W, Bias)
         out.sum().backward()
-        # TODO: test results!!!
+
+        tc.assert_almost_equal(ref, out, I, operations=C * kH * kW)
 
     #
     # This tests 1-kernel forward/ 2-kernel backward with tc.make_autograd.
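
The TC backward above keeps d_Bias as a placeholder (the in-kernel TODO notes the bias gradient is not yet correct). For reference, the usual bias gradient of a convolution is the sum of d_O over the batch and spatial dimensions; a small CPU check against autograd (assumption: plain PyTorch, not part of the commit):

    import torch
    from torch.nn.modules.conv import Conv2d

    N, C, H, W, O = 4, 3, 8, 8, 5
    I = torch.randn(N, C, H, W)
    conv = Conv2d(C, O, kernel_size=1)
    out = conv(I)
    out.sum().backward()                                  # upstream gradient d_O is all ones
    d_Bias = torch.ones_like(out).sum(dim=(0, 2, 3))      # sum d_O over N, H, W per channel
    assert torch.allclose(conv.bias.grad, d_Bias)
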
@@ -314,9 +348,12 @@ def convolution_grad(
     #
     def test_conv_with_backward_2kernels(self):
         conv = """
-        def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {
+        def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(M) Bias)
+        -> (O)
+        {
             O(n, m, h, w) +=!
                 I(n, r_c, h + r_kh, w + r_kw) * W1(m, r_c, r_kh, r_kw)
+            O(n, m, h, w) = O(n, m, h, w) + Bias(m)
         }
         def convolution_igrad(float(M,C,KH,KW) W1, float(N,M,H,W) d_O)
             -> (d_I)
@@ -329,6 +366,11 @@ def convolution_wgrad(float(N,C,H,W) I, float(N,M,H,W) d_O) -> (d_W1)
             d_W1(m, c, kh, kw) +=!
                 d_O(r_n, m, r_h - kh, r_w - kw) * I(r_n, c, r_h, r_w)
         }
+        def convolution_biasgrad(float(M) Bias) -> (d_Bias)
+        {
+            # TODO: Bias incorrect + check
+            d_Bias(m) = Bias(m)
+        }
         """
 
         N, C, H, W, O, kH, kW = 32, 4, 56, 56, 16, 1, 1
@@ -337,26 +379,34 @@ def convolution_wgrad(float(N,C,H,W) I, float(N,M,H,W) d_O) -> (d_W1)
             tc.make_autotuned_options_factory(
                 starting_options='naive',
                 tuner_config=tuner_config))
-        I, W = (
-            torch.randn(N, C, H, W, device='cuda', requires_grad=True),
-            torch.randn(O, C, kH, kW, device='cuda', requires_grad=True))
+        I = torch.randn(N, C, H, W, device='cuda', requires_grad=True)
+
+        # Reference
+        from torch.nn.modules.conv import Conv2d
+        Conv = Conv2d(C, O, 1, stride=1).cuda()
+        ref = Conv.forward(I)
 
-        def convolution_backward(I, W, d_O):
+        W = Conv.weight.clone()
+        Bias = Conv.bias.clone()
+
+        def convolution_backward(I, W, Bias, d_O):
             d_I = T.convolution_igrad(W, d_O)
             d_O = T.convolution_wgrad(I, d_O)
-            return (d_I, d_O)
+            d_Bias = T.convolution_biasgrad(Bias)
+            return (d_I, d_O, d_Bias)
 
         convolution_function = tc.make_autograd(
             T.convolution, convolution_backward)
 
         # First occurrence triggers tuning
-        out = convolution_function(I, W)
+        out = convolution_function(I, W, Bias)
         out.sum().backward()
 
         # Subsequent occurrences do not
-        out = convolution_function(I, W)
+        out = convolution_function(I, W, Bias)
         out.sum().backward()
-        # TODO: test results!!!
+
+        tc.assert_almost_equal(ref, out, I, operations=C * kH * kW)
 
     #
     # This tests the direct use of pybinds which are closer to C++
@@ -424,7 +474,13 @@ def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1)
         executor = tclib.compile(
             tensordot_str, entry_point, (I0, I1), best_options)
         O = executor.run((I0, I1), ())
-        # TODO: test results!!!
+
+        # No simple torch baseline, compare against naive
+        executor = tclib.compile(
+            tensordot_str, entry_point, (I0, I1), tc.MappingOptions('naive'))
+        ref = executor.run((I0, I1), ())
+
+        tc.assert_almost_equal(ref, O, I0, I1, operations=C2)
 
 if __name__ == '__main__':
     unittest.main()
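
The tensordot change notes there is no simple torch baseline and instead compares the tuned kernel against a naively mapped compilation of the same TC. For context only, the contraction the TC expresses (reducing over the shared C2 axis) can also be written with einsum; this sketch is an illustration, not something the test uses (assumption: plain PyTorch on CPU):

    import torch

    N, C1, C2, C3, H, W = 2, 3, 4, 5, 6, 6
    I0 = torch.randn(N, C1, C2, H, W)
    I1 = torch.randn(N, C2, C3, H, W)
    O = torch.einsum('nacxy,ncbxy->nabxy', I0, I1)    # contract over C2 (index c)
    assert O.shape == (N, C1, C3, H, W)
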
