Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 32cc5cb

Browse files
Reduce TC + PyTorch overhead
This commit reduces overhead by avoiding a 2-step allocOutputs-and-run crossing of the Python/C++ boundary. It also removes the synchronization issue mentioned in the previous commit, though we still need to find its root cause. ``` python ./test_python/pytorch_example.py raw unchecked_run naive options Total CPU time to launch kernel: min 19us, p50 29us, p90 34us, max 2204us raw unchecked_run naive options Total CPU launch + GPU kernel time: min 365us, p50 374us, p90 379us, max 2562us Tune with cache @ /tmp/d17dd046-de20-40ce-b52d-79cd3286fcb4 Note that if you pass a fixed filename, you can reinforce an existing tuning state Iteration 0 Jobs(Compiled, Evaluated)/total (25, 25)/25 (best/median/worst)us: 346/18111/196621 Iteration 1 Jobs(Compiled, Evaluated)/total (25, 25)/25 (best/median/worst)us: 346/905/1621 Iteration 2 Jobs(Compiled, Evaluated)/total (25, 25)/25 (best/median/worst)us: 335/739/1616 raw unchecked_run tuned options Total CPU time to launch kernel: min 17us, p50 21us, p90 23us, max 1030us raw unchecked_run tuned options Total CPU launch + GPU kernel time: min 350us, p50 354us, p90 360us, max 1365us TcBuilder unchecked_run Total CPU time to launch kernel: min 19us, p50 22us, p90 35us, max 2439us TcBuilder unchecked_run Total CPU launch + GPU kernel time: min 303us, p50 307us, p90 320us, max 2712us TcFunction forward unchecked_run Total CPU time to launch kernel: min 41us, p50 62us, p90 70us, max 857164us TcFunction forward unchecked_run Total CPU launch + GPU kernel time: min 317us, p50 338us, p90 351us, max 857281us TcFunction backward unchecked_run Total CPU time to launch kernel: min 344us, p50 388us, p90 412us, max 883us TcFunction backward unchecked_run Total CPU launch + GPU kernel time: min 1321us, p50 1351us, p90 1371us, max 1849us MultiTcBuilder unchecked_run Total CPU time to launch kernel: min 14us, p50 22us, p90 25us, max 1863us MultiTcBuilder unchecked_run Total CPU launch + GPU kernel time: min 298us, p50 305us, p90 310us, max 2136us MultiTcFunction forward unchecked_run Total CPU time to launch kernel: min 35us, p50 58us, p90 67us, max 506382us MultiTcFunction forward unchecked_run Total CPU launch + GPU kernel time: min 197us, p50 334us, p90 342us, max 506619us MultiTcFunction backward unchecked_run Total CPU time to launch kernel: min 275us, p50 364us, p90 383us, max 438us MultiTcFunction backward unchecked_run Total CPU launch + GPU kernel time: min 1265us, p50 1333us, p90 1350us, max 1379us ```
1 parent f0edebe commit 32cc5cb

File tree

2 files changed

+285
-77
lines changed

2 files changed

+285
-77
lines changed

tensor_comprehensions/pybinds/tclib.cc

Lines changed: 52 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -50,15 +50,15 @@ void initGlog() {
5050
}
5151

5252
inline std::vector<tc::TensorInfo> getATenTensorsAsTensorInfo(
53-
py::tuple& pyTensors) {
53+
const py::tuple& pyTensors) {
5454
std::vector<tc::TensorInfo> tensors;
5555
for (auto& inp : pyTensors) {
5656
tensors.push_back(tc::aten::toTensorInfo(inp.cast<at::Tensor>()));
5757
}
5858
return tensors;
5959
}
6060

61-
inline std::vector<at::Tensor> getATenTensors(py::tuple& pyTensors) {
61+
inline std::vector<at::Tensor> getATenTensors(const py::tuple& pyTensors) {
6262
std::vector<at::Tensor> atTensors;
6363
for (auto& inp : pyTensors) {
6464
atTensors.push_back(inp.cast<at::Tensor>());
@@ -67,15 +67,16 @@ inline std::vector<at::Tensor> getATenTensors(py::tuple& pyTensors) {
6767
}
6868

6969
template <typename VoidPtr>
70-
inline std::vector<VoidPtr> getATenTensorsAsRawPtrs(py::tuple& pyTensors) {
70+
inline std::vector<VoidPtr> getATenTensorsAsRawPtrs(
71+
const py::tuple& pyTensors) {
7172
std::vector<VoidPtr> res;
7273
for (auto& inp : pyTensors) {
7374
res.push_back(static_cast<VoidPtr>(inp.cast<at::Tensor>().data_ptr()));
7475
}
7576
return res;
7677
}
7778

78-
inline py::list convertToPyObjects(std::vector<at::Tensor>& tensors) {
79+
inline py::list convertToPyObjects(const std::vector<at::Tensor>& tensors) {
7980
py::list outputs;
8081
for (auto& tensor : tensors) {
8182
outputs.append(py::cast(torch::autograd::make_variable(tensor)));
@@ -101,7 +102,7 @@ inline py::list convertToPyObjects(std::vector<at::Tensor>& tensors) {
101102
*/
102103
struct CompilationCache {
103104
struct Key {
104-
Key(std::string entryPt, py::tuple& inputTuple)
105+
Key(std::string entryPt, const py::tuple& inputTuple)
105106
: entryPoint(entryPt), inputs(getATenTensorsAsTensorInfo(inputTuple)) {}
106107
bool operator==(const Key& other) const {
107108
return entryPoint == other.entryPoint && inputs == other.inputs;
@@ -132,15 +133,17 @@ struct CompilationCache {
132133
initGlog();
133134
}
134135

135-
bool isCompiled(const std::string& entryPoint, py::tuple& inputs) {
136+
bool isCompiled(const std::string& entryPoint, const py::tuple& inputs) {
136137
return compiled.count(Key(entryPoint, inputs)) > 0;
137138
}
138139

139140
/// This function infers the size of the outputs for each new compilation.
140141
/// This brings overhead, therefore we memoize the output sizes on-demand.
141142
/// The allocation itself is backed by ATen's caching allocator and is
142143
/// assumed acceptable (this is used everywhere in PyTorch).
143-
py::list allocOutputs(const std::string& entryPoint, py::tuple& inputs) {
144+
std::vector<at::Tensor> allocATenOutputTensors(
145+
const std::string& entryPoint,
146+
const py::tuple& inputs) {
144147
Key k(entryPoint, inputs);
145148
auto kvp = outputs.find(k);
146149
if (kvp == outputs.end()) {
@@ -156,7 +159,13 @@ struct CompilationCache {
156159
outputs.emplace(k, atOutputs);
157160
kvp = outputs.find(k);
158161
}
159-
return convertToPyObjects(kvp->second);
162+
return kvp->second;
163+
}
164+
165+
py::list allocOutputs(
166+
const std::string& entryPoint,
167+
const py::tuple& inputs) {
168+
return convertToPyObjects(allocATenOutputTensors(entryPoint, inputs));
160169
}
161170

162171
/// This function forces recompilation and storage.
@@ -166,32 +175,48 @@ struct CompilationCache {
166175
/// compiled version given an entryPoint and inputs.
167176
void compile(
168177
const std::string& entryPoint,
169-
py::tuple& inputs,
178+
const py::tuple& inputs,
170179
const tc::CudaMappingOptions& options) {
171180
Key k(entryPoint, inputs);
172181
compiled[k] = tc::aten::compile<tc::CudaBackend>(
173182
tc, entryPoint, getATenTensors(inputs), options);
174183
}
175184

176-
void
177-
run(const std::string& entryPoint, py::tuple& inputs, py::tuple& outputs) {
178-
CHECK_GE(outputs.size(), 1u)
179-
<< "run needs a tuple of output tensors to write into";
180-
auto atInputs = getATenTensors(inputs);
181-
auto atOutputs = getATenTensors(outputs);
182-
tc::aten::run(*compiled.at(Key(entryPoint, inputs)), atInputs, atOutputs);
185+
py::list run(
186+
const std::string& entryPoint,
187+
const py::tuple& inputs,
188+
const py::tuple& outputs = py::tuple()) {
189+
if (outputs.size() > 0) {
190+
auto atOutputs = getATenTensors(outputs);
191+
auto atInputs = getATenTensors(inputs);
192+
tc::aten::run(*compiled.at(Key(entryPoint, inputs)), atInputs, atOutputs);
193+
return py::list(outputs);
194+
} else {
195+
auto atOutputs = allocATenOutputTensors(entryPoint, inputs);
196+
auto atInputs = getATenTensors(inputs);
197+
tc::aten::run(*compiled.at(Key(entryPoint, inputs)), atInputs, atOutputs);
198+
return convertToPyObjects(atOutputs);
199+
}
183200
}
184201

185-
void uncheckedRun(
202+
py::list uncheckedRun(
186203
const std::string& entryPoint,
187-
py::tuple& inputs,
188-
py::tuple& outputs) {
189-
CHECK_GE(outputs.size(), 1u)
190-
<< "uncheckedRun needs a tuple of output tensors to write into";
191-
compiled.at(Key(entryPoint, inputs))
192-
->uncheckedRun(
193-
getATenTensorsAsRawPtrs<const void*>(inputs),
194-
getATenTensorsAsRawPtrs<void*>(outputs));
204+
const py::tuple& inputs,
205+
const py::tuple& outputs = py::tuple()) {
206+
if (outputs.size() > 0) {
207+
compiled.at(Key(entryPoint, inputs))
208+
->uncheckedRun(
209+
getATenTensorsAsRawPtrs<const void*>(inputs),
210+
getATenTensorsAsRawPtrs<void*>(outputs));
211+
return py::list(outputs);
212+
} else {
213+
auto outputs = allocOutputs(entryPoint, inputs);
214+
compiled.at(Key(entryPoint, inputs))
215+
->uncheckedRun(
216+
getATenTensorsAsRawPtrs<const void*>(inputs),
217+
getATenTensorsAsRawPtrs<void*>(outputs));
218+
return outputs;
219+
}
195220
}
196221

197222
std::string tc;
@@ -276,7 +301,7 @@ class MappingOptionsCache {
276301
std::vector<tc::CudaMappingOptions> load(
277302
const std::string& tc,
278303
const std::string& entryPoint,
279-
py::tuple& inputTuple,
304+
const py::tuple& inputTuple,
280305
const size_t numCandidates) {
281306
tc::autotune::OptionsCache<tc::CudaBackend> cache;
282307
cache.loadCacheFromFile(fileName_);
@@ -400,7 +425,7 @@ PYBIND11_MODULE(tclib, m) {
400425
"tune",
401426
[](Tuner& instance,
402427
const std::string& entryPoint,
403-
py::tuple& inputs,
428+
const py::tuple& inputs,
404429
tc::CudaMappingOptions& baseMapping,
405430
const TunerConfig& config) {
406431
config.__enter__();

0 commit comments

Comments (0)