
Commit 11b0c31

Add simpler pybind entry point to TC
If the Python user can memoize the executor resulting from compilation, then this API should be preferred. Otherwise, as in PyTorch autograd functions, the compilation cache should be used.
1 parent bfde5ea commit 11b0c31
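
For orientation, here is a minimal sketch of the two usage patterns the commit message contrasts, using the names from the example diff below (mm is the TC definition string, mat1/mat2 the input tensors); it assumes MappingOptions is importable from tensor_comprehensions.tclib alongside the other symbols:

    from tensor_comprehensions.tclib import compile, CompilationCache, MappingOptions

    # Preferred path: the caller memoizes the executor returned by compile().
    executor = compile(mm, "matmul", (mat1, mat2), MappingOptions())
    outputs = executor.run((mat1, mat2), ())                        # infers and allocates outputs
    outputs = executor.unchecked_run((mat1, mat2), tuple(outputs))  # reuses the allocated outputs

    # Fallback path (e.g. PyTorch autograd functions): keep a CompilationCache.
    compilation_cache = CompilationCache(mm)
    compilation_cache.compile("matmul", (mat1, mat2), MappingOptions())
    outputs = compilation_cache.unchecked_run("matmul", (mat1, mat2), ())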

File tree: 2 files changed, +100 / -15 lines

python/examples/tc_pybind_example.py

Lines changed: 50 additions & 15 deletions
@@ -17,6 +17,13 @@
 
 dump_backward_overhead = False
 
+################################################################################
+# The purpose of these examples is to demonstrate the usage of the python
+# bindings to build a simple, low-overhead, python abstraction.
+# We demonstrate the bindings by building a series of examples leading to a
+# MultiTcFunction abstraction for PyTorch autograd.
+################################################################################
+
 ################################################################################
 # 0. Initializations
 ################################################################################
@@ -33,7 +40,7 @@ def time_tc(iters, prepend, runFun, tc_name, inputs):
     start = time.clock()
     if dump_backward_overhead:
         dump_backward_overhead = time.clock()
-    outputs = runFun(tc_name, inputs, ())
+    outputs = runFun(tc_name, inputs)
     timesCPU.append(time.clock() - start)
     torch.cuda.synchronize()
    timesCPUAndGPU.append(time.clock() - start)
@@ -68,23 +75,51 @@ def matmul_grad(float(M,N) A, float(N,K) B, float(M,K) d_O) -> (d_A, d_B) {
 mat1, mat2 = torch.randn(300, 400).cuda(), torch.randn(400, 500).cuda()
 
 ################################################################################
-# 1. Use the C++ API to build a low-overhead compilation cache and time it
+# 1. Use the simple high-overhead compile/run C++ API
+# If one can keep state in their layer or wishes to experiment with TC,
+# this is a simple entry point.
+# If state cannot be kept, be aware that this API has a non-trivial overhead
+# when output sizes need to be inferred and outputs allocated.
+# Compilation itself has a prohibitive cost and needs to be memoized either
+# by holding on to the executor or by using the low-overhead abstraction, see
+# below.
+################################################################################
+from tensor_comprehensions.tclib import compile
+
+executor = compile(mm, "matmul", (mat1, mat2), MappingOptions())
+outputs = executor.run((mat1, mat2), ())
+outputs = executor.unchecked_run((mat1, mat2), tuple(outputs))
+time_tc(100,
+        "simple API\t",
+        lambda name, ins: executor.unchecked_run(ins, tuple(outputs)),
+        "matmul",
+        (mat1, mat2))
+time_tc(100,
+        "simple API (with allocation overhead)\t",
+        lambda name, ins: executor.unchecked_run(ins, ()),
+        "matmul",
+        (mat1, mat2))
+
+################################################################################
+# 2. Use the C++ API to build a low-overhead compilation cache and time it
 ################################################################################
 from tensor_comprehensions.tclib import CompilationCache
 
 compilation_cache = CompilationCache(mm)
 # Compilation returns an allocated tuple of outputs with the proper shapes.
 # Allocation overhead is negligible compared to compilation overhead.
 compilation_cache.compile("matmul", (mat1, mat2), MappingOptions())
+# Run once without timing
+compilation_cache.unchecked_run("matmul", (mat1, mat2), ())
 # unchecked_run on tensors
 time_tc(100,
         "raw unchecked_run naive options\t",
-        lambda name, ins, outs: compilation_cache.unchecked_run(name, ins, outs),
+        lambda name, ins: compilation_cache.unchecked_run(name, ins, ()),
         "matmul",
         (mat1, mat2))
 
 ################################################################################
-# 2. Short tuning run saving to file then load the best option to create a
+# 3. Short tuning run saving to file then load the best option to create a
 # compilation cache
 ################################################################################
 from tensor_comprehensions.tclib import Tuner
@@ -111,12 +146,12 @@ def matmul_grad(float(M,N) A, float(N,K) B, float(M,K) d_O) -> (d_A, d_B) {
 compilation_cache.compile("matmul", (mat1, mat2), top1)
 time_tc(100,
         "raw unchecked_run tuned options\t",
-        lambda name, ins, outs: compilation_cache.unchecked_run(name, ins, outs),
+        lambda name, ins: compilation_cache.unchecked_run(name, ins, ()),
         "matmul",
         (mat1, mat2))
 
 ################################################################################
-# 3. Simple TC builder
+# 4. Simple TC builder
 ################################################################################
 class TcBuilder():
     def __init__(self,
@@ -200,12 +235,12 @@ def compileOrTune(self, name = "", force_reinforcement_tuning = False, inputs =
 tcb.compileOrTune(name = "matmul", inputs = (mat1, mat2))
 time_tc(100,
         "TcBuilder unchecked_run\t",
-        lambda name, ins, outs: tcb.compilation_cache.unchecked_run(name, ins, outs),
+        lambda name, ins: tcb.compilation_cache.unchecked_run(name, ins, ()),
         "matmul",
         (mat1, mat2))
 
 ################################################################################
-# 4. Simple torch.autograd.Function backed by TcBuilder
+# 5. Simple torch.autograd.Function backed by TcBuilder
 ################################################################################
 class TcFunction(torch.autograd.Function):
     @staticmethod
@@ -283,7 +318,7 @@ def backward(ctx, *gradients):
 
 time_tc(100,
         "TcFunction forward unchecked_run\t",
-        lambda name, ins, outs: TcFunction.apply(tcb, *ins),
+        lambda name, ins: TcFunction.apply(tcb, *ins),
         "matmul",
         (mat1, mat2))
 
@@ -306,7 +341,7 @@ def backward(ctx, *gradients):
 dump_backward_overhead = False
 time_tc(100,
         "TcFunction backward unchecked_run\t",
-        lambda name, ins, outs: outputs[0].backward(grad_sized_tensor, retain_graph = True),
+        lambda name, ins: outputs[0].backward(grad_sized_tensor, retain_graph = True),
         "matmul",
         (mat1, mat2))
 
@@ -316,7 +351,7 @@ def backward(ctx, *gradients):
 v.backward(retain_graph = True)
 
 ################################################################################
-# 5. Multi-TC builder
+# 6. Multi-TC builder
 ################################################################################
 class MultiTcBuilder():
     def __init__(self,
@@ -404,12 +439,12 @@ def compileOrTune(self, name = "", force_reinforcement_tuning = False, inputs =
 tcb.compileOrTune(name = "matmul", inputs = (mat1, mat2))
 time_tc(100,
         "MultiTcBuilder unchecked_run\t",
-        lambda name, ins, outs: tcb.compilation_cache.unchecked_run(name, ins, outs),
+        lambda name, ins: tcb.compilation_cache.unchecked_run(name, ins, ()),
         "matmul",
         (mat1, mat2))
 
 ################################################################################
-# 6. Multi-TC torch.autograd.Function backed by MultiTcBuilder
+# 7. Multi-TC torch.autograd.Function backed by MultiTcBuilder
 ################################################################################
 class MultiTcFunction(torch.autograd.Function):
     @staticmethod
@@ -508,7 +543,7 @@ def backward(ctx, *gradients):
 
 time_tc(100,
         "MultiTcFunction forward unchecked_run\t",
-        lambda name, ins, outs: MultiTcFunction.apply(tcb, *ins),
+        lambda name, ins: MultiTcFunction.apply(tcb, *ins),
         "matmul",
         (mat1, mat2))
 
@@ -531,7 +566,7 @@ def backward(ctx, *gradients):
 dump_backward_overhead = False
 time_tc(100,
         "MultiTcFunction backward unchecked_run\t",
-        lambda name, ins, outs: outputs[0].backward(grad_sized_tensor, retain_graph = True),
+        lambda name, ins: outputs[0].backward(grad_sized_tensor, retain_graph = True),
         "matmul",
         (mat1, mat2))
 
tensor_comprehensions/pybinds/tclib.cc

Lines changed: 50 additions & 0 deletions
@@ -236,6 +236,42 @@ class Tuner : public ATenCudaGeneticTuner {
   std::string cacheFileName;
 };
 
+struct TcExecutor {
+  py::list run(
+      const py::tuple& inputs,
+      const py::tuple& outputs = py::tuple()) {
+    if (outputs.size() > 0) {
+      auto atOutputs = getATenTensors(outputs);
+      auto atInputs = getATenTensors(inputs);
+      tc::aten::run(*executor, atInputs, atOutputs);
+      return py::list(outputs);
+    } else {
+      auto atInputs = getATenTensors(inputs);
+      auto atOutputs = tc::aten::prepareOutputs(tc, entryPoint, atInputs);
+      tc::aten::run(*executor, atInputs, atOutputs);
+      return convertToPyObjects(atOutputs);
+    }
+  }
+  py::list uncheckedRun(
+      const py::tuple& inputs,
+      const py::tuple& outputs = py::tuple()) {
+    if (outputs.size() > 0) {
+      auto atOutputs = getATenTensors(outputs);
+      auto atInputs = getATenTensors(inputs);
+      tc::aten::uncheckedRun(*executor, atInputs, atOutputs);
+      return py::list(outputs);
+    } else {
+      auto atInputs = getATenTensors(inputs);
+      auto atOutputs = tc::aten::prepareOutputs(tc, entryPoint, atInputs);
+      tc::aten::uncheckedRun(*executor, atInputs, atOutputs);
+      return convertToPyObjects(atOutputs);
+    }
+  }
+  std::string tc;
+  std::string entryPoint;
+  std::unique_ptr<tc::CudaBackend::ExecutorType> executor;
+};
+
 class TunerConfig {
  public:
   TunerConfig(
@@ -345,6 +381,20 @@ PYBIND11_MODULE(tclib, m) {
   m.def(
       "set_dump_cuda", [](bool dump_cuda) { tc::FLAGS_dump_cuda = dump_cuda; });
 
+  py::class_<TcExecutor>(m, "TcExecutor", py::module_local())
+      .def("run", &TcExecutor::run)
+      .def("unchecked_run", &TcExecutor::uncheckedRun);
+  m.def(
+      "compile",
+      [](const std::string& tc,
+         const std::string& entryPoint,
+         const py::tuple& inputs,
+         const tc::CudaMappingOptions& options) {
+        auto execUPtr = tc::aten::compile<tc::CudaBackend>(
+            tc, entryPoint, getATenTensors(inputs), options);
+        return TcExecutor{tc, entryPoint, std::move(execUPtr)};
+      });
+
   py::class_<TunerConfig>(m, "TunerConfig", py::module_local())
       .def(
           py::init<uint32_t, uint32_t, uint32_t, std::string, bool, uint32_t>(),
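
From Python, the practical effect of the new TcExecutor binding is that output allocation can be skipped on steady-state calls: a non-empty outputs tuple takes the first branch above and writes into the given tensors, while an empty tuple goes through tc::aten::prepareOutputs. A small illustrative sketch (not part of the commit), reusing mm, mat1 and mat2 from the example file:

    executor = compile(mm, "matmul", (mat1, mat2), MappingOptions())
    # First call: empty outputs tuple, so sizes are inferred and tensors allocated.
    outputs = executor.run((mat1, mat2), ())
    # Steady state: pass the previously allocated outputs to avoid that overhead.
    outputs = executor.unchecked_run((mat1, mat2), tuple(outputs))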
