@@ -17,6 +17,13 @@

dump_backward_overhead = False

+ ################################################################################
+ # The purpose of these examples is to demonstrate the usage of the python
+ # bindings to build a simple, low-overhead, python abstraction.
+ # We demonstrate the bindings by building a series of examples leading to a
+ # MultiTcFunction abstraction for PyTorch autograd.
+ ################################################################################
+
################################################################################
# 0. Initializations
################################################################################
@@ -33,7 +40,7 @@ def time_tc(iters, prepend, runFun, tc_name, inputs):
start = time.clock()
if dump_backward_overhead:
    dump_backward_overhead = time.clock()
- outputs = runFun(tc_name, inputs, ())
+ outputs = runFun(tc_name, inputs)
timesCPU.append(time.clock() - start)
torch.cuda.synchronize()
timesCPUAndGPU.append(time.clock() - start)
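
With this change, every callable handed to time_tc takes only the TC name and the input tensors; how outputs are produced or reused is left entirely to the callable. A minimal sketch of the new convention (it reuses the compilation_cache object built in section 2 below, exactly as the benchmarks later in the file do):

    # A runFun under the new two-argument convention: no outputs tuple is
    # threaded through time_tc anymore.
    run_matmul = lambda tc_name, inputs: compilation_cache.unchecked_run(tc_name, inputs, ())
    time_tc(100, "example\t", run_matmul, "matmul", (mat1, mat2))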
@@ -68,23 +75,51 @@ def matmul_grad(float(M,N) A, float(N,K) B, float(M,K) d_O) -> (d_A, d_B) {
mat1, mat2 = torch.randn(300, 400).cuda(), torch.randn(400, 500).cuda()

################################################################################
- # 1. Use the C++ API to build a low-overhead compilation cache and time it
+ # 1. Use the simple high-overhead compile/run C++ API
+ # If you can keep state in your layer or wish to experiment with TC,
+ # this is a simple entry point.
+ # If state cannot be kept, be aware that this API has a non-trivial overhead
+ # when output sizes need to be inferred and outputs allocated.
+ # Compilation itself has a prohibitive cost and needs to be memoized, either
+ # by holding on to the executor or by using the low-overhead abstraction
+ # below.
+ ################################################################################
+ from tensor_comprehensions.tclib import compile
+
+ executor = compile(mm, "matmul", (mat1, mat2), MappingOptions())
+ outputs = executor.run((mat1, mat2), ())
+ outputs = executor.unchecked_run((mat1, mat2), tuple(outputs))
+ time_tc(100,
+         "simple API\t",
+         lambda name, ins: executor.unchecked_run(ins, tuple(outputs)),
+         "matmul",
+         (mat1, mat2))
+ time_tc(100,
+         "simple API (with allocation overhead)\t",
+         lambda name, ins: executor.unchecked_run(ins, ()),
+         "matmul",
+         (mat1, mat2))
+
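
As the comment above says, the compilation cost can be amortized simply by holding on to the executor. A minimal sketch of such memoization, using only the compile/unchecked_run calls shown above; the _executors dict and run_memoized helper are illustrative, not part of the bindings:

    # Illustrative memoization: compile once per (tc name, input shapes),
    # then reuse the stored executor for every subsequent run.
    _executors = {}

    def run_memoized(tc, name, inputs, options):
        key = (name, tuple(tuple(t.size()) for t in inputs))
        if key not in _executors:
            _executors[key] = compile(tc, name, inputs, options)
        # Outputs are allocated on each call here; see section 2 for the
        # low-overhead alternative.
        return _executors[key].unchecked_run(inputs, ())

    outs = run_memoized(mm, "matmul", (mat1, mat2), MappingOptions())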
+ ################################################################################
+ # 2. Use the C++ API to build a low-overhead compilation cache and time it
################################################################################
from tensor_comprehensions.tclib import CompilationCache

compilation_cache = CompilationCache(mm)
# Compilation returns an allocated tuple of outputs with the proper shapes.
# Allocation overhead is negligible compared to compilation overhead.
compilation_cache.compile("matmul", (mat1, mat2), MappingOptions())
+ # Run once without timing
+ compilation_cache.unchecked_run("matmul", (mat1, mat2), ())
# unchecked_run on tensors
time_tc(100,
        "raw unchecked_run naive options\t",
-         lambda name, ins, outs: compilation_cache.unchecked_run(name, ins, outs),
+         lambda name, ins: compilation_cache.unchecked_run(name, ins, ()),
        "matmul",
        (mat1, mat2))

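
Since compilation already returns a correctly shaped tuple of outputs, that tuple can be held on to so allocation stays out of the timed region. A sketch, assuming the cache's unchecked_run accepts a preallocated output tuple the same way the executor API in section 1 does:

    # Reuse the outputs allocated at compile time so the timed loop measures
    # only kernel launch, not output allocation.
    outs = compilation_cache.compile("matmul", (mat1, mat2), MappingOptions())
    time_tc(100,
            "raw unchecked_run naive options, reused outputs\t",
            lambda name, ins: compilation_cache.unchecked_run(name, ins, tuple(outs)),
            "matmul",
            (mat1, mat2))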
################################################################################
- # 2. Short tuning run saving to file then load the best option to create a
+ # 3. Short tuning run saving to file then load the best option to create a
# compilation cache
################################################################################
from tensor_comprehensions.tclib import Tuner
@@ -111,12 +146,12 @@ def matmul_grad(float(M,N) A, float(N,K) B, float(M,K) d_O) -> (d_A, d_B) {
compilation_cache.compile("matmul", (mat1, mat2), top1)
time_tc(100,
        "raw unchecked_run tuned options\t",
-         lambda name, ins, outs: compilation_cache.unchecked_run(name, ins, outs),
+         lambda name, ins: compilation_cache.unchecked_run(name, ins, ()),
        "matmul",
        (mat1, mat2))

################################################################################
- # 3. Simple TC builder
+ # 4. Simple TC builder
################################################################################
class TcBuilder():
    def __init__(self,
@@ -200,12 +235,12 @@ def compileOrTune(self, name = "", force_reinforcement_tuning = False, inputs =
tcb.compileOrTune(name = "matmul", inputs = (mat1, mat2))
time_tc(100,
        "TcBuilder unchecked_run\t",
-         lambda name, ins, outs: tcb.compilation_cache.unchecked_run(name, ins, outs),
+         lambda name, ins: tcb.compilation_cache.unchecked_run(name, ins, ()),
        "matmul",
        (mat1, mat2))

################################################################################
- # 4. Simple torch.autograd.Function backed by TcBuilder
+ # 5. Simple torch.autograd.Function backed by TcBuilder
################################################################################
class TcFunction(torch.autograd.Function):
    @staticmethod
@@ -283,7 +318,7 @@ def backward(ctx, *gradients):

time_tc(100,
        "TcFunction forward unchecked_run\t",
-         lambda name, ins, outs: TcFunction.apply(tcb, *ins),
+         lambda name, ins: TcFunction.apply(tcb, *ins),
        "matmul",
        (mat1, mat2))

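
The benchmark above only times the forward call. For orientation, a minimal sketch of how the same TC-backed Function would sit in ordinary autograd code, assuming leaves that require grad and the matmul/matmul_grad pair defined in mm (the names a and b are illustrative):

    # Illustrative autograd usage, not part of the benchmark.
    a = mat1.clone().requires_grad_()
    b = mat2.clone().requires_grad_()
    out = TcFunction.apply(tcb, a, b)
    # Backprop through the TC-backed op; gradients flow via the matmul_grad TC.
    out[0].sum().backward()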
@@ -306,7 +341,7 @@ def backward(ctx, *gradients):
dump_backward_overhead = False
time_tc(100,
        "TcFunction backward unchecked_run\t",
-         lambda name, ins, outs: outputs[0].backward(grad_sized_tensor, retain_graph = True),
+         lambda name, ins: outputs[0].backward(grad_sized_tensor, retain_graph = True),
        "matmul",
        (mat1, mat2))

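
The backward benchmark feeds a fixed gradient tensor; grad_sized_tensor is defined earlier in the file, outside the hunks shown here. If one were reconstructing it, any tensor shaped like the forward output would do, for example the hypothetical line below:

    # Hypothetical construction: the gradient only has to match the shape of
    # the forward output, here the (M, K) matmul result.
    grad_sized_tensor = torch.randn(outputs[0].size()).cuda()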
@@ -316,7 +351,7 @@ def backward(ctx, *gradients):
v.backward(retain_graph = True)

################################################################################
- # 5. Multi-TC builder
+ # 6. Multi-TC builder
################################################################################
class MultiTcBuilder():
    def __init__(self,
@@ -404,12 +439,12 @@ def compileOrTune(self, name = "", force_reinforcement_tuning = False, inputs =
tcb.compileOrTune(name = "matmul", inputs = (mat1, mat2))
time_tc(100,
        "MultiTcBuilder unchecked_run\t",
-         lambda name, ins, outs: tcb.compilation_cache.unchecked_run(name, ins, outs),
+         lambda name, ins: tcb.compilation_cache.unchecked_run(name, ins, ()),
        "matmul",
        (mat1, mat2))

################################################################################
- # 6. Multi-TC torch.autograd.Function backed by MultiTcBuilder
+ # 7. Multi-TC torch.autograd.Function backed by MultiTcBuilder
################################################################################
class MultiTcFunction(torch.autograd.Function):
    @staticmethod
@@ -508,7 +543,7 @@ def backward(ctx, *gradients):

time_tc(100,
        "MultiTcFunction forward unchecked_run\t",
-         lambda name, ins, outs: MultiTcFunction.apply(tcb, *ins),
+         lambda name, ins: MultiTcFunction.apply(tcb, *ins),
        "matmul",
        (mat1, mat2))

@@ -531,7 +566,7 @@ def backward(ctx, *gradients):
dump_backward_overhead = False
time_tc(100,
        "MultiTcFunction backward unchecked_run\t",
-         lambda name, ins, outs: outputs[0].backward(grad_sized_tensor, retain_graph = True),
+         lambda name, ins: outputs[0].backward(grad_sized_tensor, retain_graph = True),
        "matmul",
        (mat1, mat2))
