Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit c607c68

Browse files
author
Jules Pondard
committed
Add getDict function to PyBinds
Given a CudaMappingOptions instance, return a Python dictionary that represents all the given options and their values.
1 parent 4cd991c commit c607c68

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

tensor_comprehensions/pybinds/tclib.cc

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,38 @@ PYBIND11_MODULE(tclib, m) {
674674
return str;
675675
},
676676
"Returns the CudaMappingOptions as a human-readable string")
677+
.def(
678+
"getDict",
679+
[](tc::CudaMappingOptions& instance) {
680+
py::dict rv;
681+
rv["outerScheduleFusionStrategy"] = FusionStrategy_Name(
682+
instance.generic.outerScheduleOptions.proto.fusion_strategy());
683+
if (instance.generic.proto.has_intra_tile_schedule_options())
684+
rv["intraTileScheduleFusionStrategy"] =
685+
FusionStrategy_Name(instance.generic.intraTileScheduleOptions
686+
.proto.fusion_strategy());
687+
rv["fixParametersBeforeScheduling"] =
688+
instance.generic.proto.fix_parameters_before_scheduling();
689+
if (instance.generic.proto.has_tiling())
690+
rv["tile"] = instance.generic.tiling.extractVector();
691+
if (instance.generic.proto.has_unroll())
692+
rv["unroll"] = instance.generic.proto.unroll();
693+
rv["tileImperfectlyNested"] =
694+
instance.generic.proto.tile_imperfectly_nested();
695+
rv["matchLibraryCalls"] =
696+
instance.generic.proto.match_library_calls();
697+
rv["mapToThreads"] = instance.block.extractVector();
698+
rv["mapToBlocks"] = instance.grid.extractVector();
699+
rv["useSharedMemory"] = instance.proto().use_shared_memory();
700+
rv["usePrivateMemory"] = instance.proto().use_private_memory();
701+
rv["unrollCopyShared"] = instance.proto().unroll_copy_shared();
702+
rv["useReadOnlyCache"] = instance.proto().use_readonly_cache();
703+
if (instance.proto().has_max_shared_memory())
704+
rv["maxSharedMemory"] = instance.proto().max_shared_memory();
705+
rv["privateDepth"] = instance.proto().private_depth();
706+
return rv;
707+
},
708+
"Returns a dictionary with the CudaMappingOptions")
677709
.def(
678710
"serialize",
679711
[](tc::CudaMappingOptions& instance) {

0 commit comments

Comments
 (0)