This repository was archived by the owner on Apr 28, 2023. It is now read-only.
File tree Expand file tree Collapse file tree 1 file changed +17
-0
lines changed
tensor_comprehensions/pybinds Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Original file line number Diff line number Diff line change @@ -273,6 +273,17 @@ struct TcExecutor {
273
273
return tupleOrTensor (convertToPyObjects (atOutputs));
274
274
}
275
275
}
276
+
277
+ size_t profile_kernel (const py::tuple& inputs, const py::tuple& outputs) {
278
+ auto atInputs = getATenTensors (inputs);
279
+ auto atOutputs = (outputs.size () > 0 )
280
+ ? getATenTensors (outputs)
281
+ : tc::aten::prepareOutputs (tc, entryPoint, atInputs);
282
+ tc::ProfilingInfo profinfo =
283
+ tc::aten::profile (*executor, atInputs, atOutputs);
284
+ return profinfo.kernelRuntime .toMicroSeconds ();
285
+ }
286
+
276
287
std::string tc;
277
288
std::string entryPoint;
278
289
std::unique_ptr<tc::CudaBackend::ExecutorType> executor;
@@ -465,7 +476,13 @@ PYBIND11_MODULE(tclib, m) {
465
476
" unchecked_run" ,
466
477
&TcExecutor::uncheckedRun,
467
478
py::arg (" inputs" ),
479
+ py::arg (" outputs" ) = py::tuple ())
480
+ .def (
481
+ " profile_kernel" ,
482
+ &TcExecutor::profile_kernel,
483
+ py::arg (" inputs" ),
468
484
py::arg (" outputs" ) = py::tuple ());
485
+
469
486
m.def (
470
487
" compile" ,
471
488
[](const std::string& tc,
You can’t perform that action at this time.
0 commit comments