Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit ab0982f

Browse files
author
Jules Pondard
committed
Add profile_kernel function to PyBinds
This function returns the execution time, in microseconds, of a given kernel for the given inputs and specified options. Useful for benchmarking purposes.
1 parent 0899ee1 commit ab0982f

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

tensor_comprehensions/pybinds/tclib.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,17 @@ struct TcExecutor {
273273
return tupleOrTensor(convertToPyObjects(atOutputs));
274274
}
275275
}
276+
277+
// Profiles the kernel held by `executor` and returns its measured runtime.
//
// inputs:  Python tuple converted with getATenTensors().
// outputs: Python tuple of output buffers; when it is empty, outputs are
//          obtained from tc::aten::prepareOutputs(tc, entryPoint, atInputs).
// returns: profinfo.kernelRuntime converted with toMicroSeconds(); no other
//          member of the returned tc::ProfilingInfo is read here.
// NOTE(review): `executor` is dereferenced without a null check -- presumably
// callers only reach this after a successful compile; confirm.
size_t profile_kernel(const py::tuple& inputs, const py::tuple& outputs) {
278+
auto atInputs = getATenTensors(inputs);
279+
auto atOutputs = (outputs.size() > 0) // caller-supplied outputs take precedence
280+
? getATenTensors(outputs)
281+
: tc::aten::prepareOutputs(tc, entryPoint, atInputs); // otherwise derive them from the TC, entry point and inputs
282+
tc::ProfilingInfo profinfo =
283+
tc::aten::profile(*executor, atInputs, atOutputs);
284+
return profinfo.kernelRuntime.toMicroSeconds(); // only the kernel runtime is reported
285+
}
286+
276287
std::string tc;
277288
std::string entryPoint;
278289
std::unique_ptr<tc::CudaBackend::ExecutorType> executor;
@@ -465,7 +476,13 @@ PYBIND11_MODULE(tclib, m) {
465476
"unchecked_run",
466477
&TcExecutor::uncheckedRun,
467478
py::arg("inputs"),
479+
py::arg("outputs") = py::tuple())
480+
.def(
481+
"profile_kernel",
482+
&TcExecutor::profile_kernel,
483+
py::arg("inputs"),
468484
py::arg("outputs") = py::tuple());
485+
469486
m.def(
470487
"compile",
471488
[](const std::string& tc,

0 commit comments

Comments
 (0)