Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit ac4ce08

Browse files
author
Jules Pondard
committed
Add profile_kernel function to PyBinds
This function gives the time of execution of a given kernel with an input and specified options. Useful for benchmarking purposes.
1 parent f8bdb8e commit ac4ce08

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

tensor_comprehensions/pybinds/tclib.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,17 @@ struct TcExecutor {
274274
return tupleOrTensor(convertToPyObjects(atOutputs));
275275
}
276276
}
277+
278+
size_t profile_kernel(const py::tuple& inputs, const py::tuple& outputs) {
279+
auto atInputs = getATenTensors(inputs);
280+
auto atOutputs = (outputs.size() > 0)
281+
? getATenTensors(outputs)
282+
: tc::aten::prepareOutputs(tc, entryPoint, atInputs);
283+
tc::ProfilingInfo profinfo =
284+
tc::aten::profile(*executor, atInputs, atOutputs);
285+
return profinfo.kernelRuntime.toMicroSeconds();
286+
}
287+
277288
std::string tc;
278289
std::string entryPoint;
279290
std::unique_ptr<tc::CudaBackend::ExecutorType> executor;
@@ -485,7 +496,13 @@ PYBIND11_MODULE(tclib, m) {
485496
"unchecked_run",
486497
&TcExecutor::uncheckedRun,
487498
py::arg("inputs"),
499+
py::arg("outputs") = py::tuple())
500+
.def(
501+
"profile_kernel",
502+
&TcExecutor::profile_kernel,
503+
py::arg("inputs"),
488504
py::arg("outputs") = py::tuple());
505+
489506
m.def(
490507
"compile",
491508
[](const std::string& tc,

0 commit comments

Comments
 (0)