@@ -514,7 +514,11 @@ PYBIND11_MODULE(tclib, m) {
514
514
" MappingOptions for a Tensor Comprehensions (TC)" ,
515
515
py::module_local ())
516
516
.def (
517
- py::init ([]() {
517
+ py::init ([](const std::string& optionsName) {
518
+ TC_CHECK_EQ (optionsName, " naive" )
519
+ << " Naive options are the only constructible user-facing "
520
+ << " options. We recommended using the tuner to get better "
521
+ << " options or, alternatively, retrieving some from a cache." ;
518
522
return tc::CudaMappingOptions::makeNaiveMappingOptions ();
519
523
}),
520
524
" Initialize naive CudaMappingOption" )
@@ -543,6 +547,10 @@ PYBIND11_MODULE(tclib, m) {
543
547
&tc::CudaMappingOptions::useSharedMemory,
544
548
" Create block-local copies of data in shared memory when this can "
545
549
" leverage data reuse or global memory access coalescing" )
550
+ .def (
551
+ " usePrivateMemory" ,
552
+ &tc::CudaMappingOptions::usePrivateMemory,
553
+ " Create thread-local copies of data in private memory" )
546
554
.def (
547
555
" unrollCopyShared" ,
548
556
&tc::CudaMappingOptions::unrollCopyShared,
@@ -556,13 +564,15 @@ PYBIND11_MODULE(tclib, m) {
556
564
" scheduleFusionStrategy" ,
557
565
[](tc::CudaMappingOptions& instance, const std::string& type) {
558
566
instance.scheduleFusionStrategy (type);
567
+ return instance;
559
568
},
560
569
" Set up outerScheduleFusionStrategy and intraTileFusionStrategy "
561
570
" to the given value" )
562
571
.def (
563
572
" outerScheduleFusionStrategy" ,
564
573
[](tc::CudaMappingOptions& instance, const std::string& type) {
565
574
instance.outerScheduleFusionStrategy (type);
575
+ return instance;
566
576
},
567
577
" Require TC to try and execute different TC expressions interleaved "
568
578
" (Max), separately (Min)\n "
@@ -574,6 +584,7 @@ PYBIND11_MODULE(tclib, m) {
574
584
" intraTileScheduleFusionStrategy" ,
575
585
[](tc::CudaMappingOptions& instance, const std::string& type) {
576
586
instance.intraTileScheduleFusionStrategy (type);
587
+ return instance;
577
588
},
578
589
" Require TC to try and execute different TC expressions interleaved "
579
590
" (Max), separately (Min)\n "
@@ -584,7 +595,10 @@ PYBIND11_MODULE(tclib, m) {
584
595
" tile" ,
585
596
// pybind11 has implicit conversion from tuple -> vector
586
597
[](tc::CudaMappingOptions& instance,
587
- std::vector<uint64_t >& tileSizes) { instance.tile (tileSizes); },
598
+ std::vector<uint64_t >& tileSizes) {
599
+ instance.tile (tileSizes);
600
+ return instance;
601
+ },
588
602
" Perform loop tiling on the generated code with the given sizes. "
589
603
" Independent of mapping to a\n "
590
604
" grid of thread blocks" )
@@ -593,6 +607,7 @@ PYBIND11_MODULE(tclib, m) {
593
607
[](tc::CudaMappingOptions& instance,
594
608
std::vector<uint64_t >& threadSizes) {
595
609
instance.mapToThreads (threadSizes);
610
+ return instance;
596
611
},
597
612
" The configuration of CUDA block, i.e. the number of CUDA threads "
598
613
" in each block along three\n "
@@ -604,6 +619,7 @@ PYBIND11_MODULE(tclib, m) {
604
619
[](tc::CudaMappingOptions& instance,
605
620
std::vector<uint64_t >& blockSizes) {
606
621
instance.mapToBlocks (blockSizes);
622
+ return instance;
607
623
},
608
624
" The configuration of CUDA grid, i.e. the number of CUDA blocks "
609
625
" along three dimensions. Must be\n "
@@ -613,13 +629,15 @@ PYBIND11_MODULE(tclib, m) {
613
629
" matchLibraryCalls" ,
614
630
[](tc::CudaMappingOptions& instance, bool match) {
615
631
instance.matchLibraryCalls (match);
632
+ return instance;
616
633
},
617
634
" Replace computation patterns with calls to highly optimized "
618
635
" libraries (such as CUB, CUTLASS) when possible" )
619
636
.def (
620
637
" fixParametersBeforeScheduling" ,
621
638
[](tc::CudaMappingOptions& instance, bool fix) {
622
639
instance.fixParametersBeforeScheduling (fix);
640
+ return instance;
623
641
},
624
642
" Perform automatic loop scheduling taking into account specific "
625
643
" tensor sizes.\n "
@@ -631,6 +649,7 @@ PYBIND11_MODULE(tclib, m) {
631
649
" unroll" ,
632
650
[](tc::CudaMappingOptions& instance, uint64_t factor) {
633
651
instance.unroll (factor);
652
+ return instance;
634
653
},
635
654
" Perform loop unrolling on the generated code and produce at "
636
655
" most the given number of statements" );
0 commit comments