This repository was archived by the owner on Apr 28, 2023. It is now read-only.
File tree Expand file tree Collapse file tree 1 file changed +17
-0
lines changed
tensor_comprehensions/pybinds Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Original file line number Diff line number Diff line change @@ -274,6 +274,17 @@ struct TcExecutor {
274
274
return tupleOrTensor (convertToPyObjects (atOutputs));
275
275
}
276
276
}
277
+
278
+ size_t profile_kernel (const py::tuple& inputs, const py::tuple& outputs) {
279
+ auto atInputs = getATenTensors (inputs);
280
+ auto atOutputs = (outputs.size () > 0 )
281
+ ? getATenTensors (outputs)
282
+ : tc::aten::prepareOutputs (tc, entryPoint, atInputs);
283
+ tc::ProfilingInfo profinfo =
284
+ tc::aten::profile (*executor, atInputs, atOutputs);
285
+ return profinfo.kernelRuntime .toMicroSeconds ();
286
+ }
287
+
277
288
std::string tc;
278
289
std::string entryPoint;
279
290
std::unique_ptr<tc::CudaBackend::ExecutorType> executor;
@@ -485,7 +496,13 @@ PYBIND11_MODULE(tclib, m) {
485
496
" unchecked_run" ,
486
497
&TcExecutor::uncheckedRun,
487
498
py::arg (" inputs" ),
499
+ py::arg (" outputs" ) = py::tuple ())
500
+ .def (
501
+ " profile_kernel" ,
502
+ &TcExecutor::profile_kernel,
503
+ py::arg (" inputs" ),
488
504
py::arg (" outputs" ) = py::tuple ());
505
+
489
506
m.def (
490
507
" compile" ,
491
508
[](const std::string& tc,
You can’t perform that action at this time.
0 commit comments