@@ -461,45 +461,13 @@ bool hasOuterSequentialMember(
461
461
return false ;
462
462
}
463
463
464
+ // Name of the space of blocks inside the grid
// (wrapped in an isl::id and passed as the tuple identifier to
// extractDomainToIds by MappedScop::blockMappingSchedule).
465
+ constexpr auto kGrid = " grid" ;
464
466
// Name of the space of threads inside a block
// (wrapped in an isl::id and passed as the tuple identifier to
// extractDomainToIds by MappedScop::threadMappingSchedule).
465
467
constexpr auto kBlock = " block" ;
466
468
// Name of the space of warps
467
469
constexpr auto kWarp = " warp" ;
468
470
469
// NOTE(review): this hunk deletes extractDomainToThread; the call in
// MappedScop::insertBestSyncInSeq is switched to
// MappedScop::threadMappingSchedule in the same change.
- /*
470
- * Extract a mapping from the domain elements active at "tree"
471
- * to the thread identifiers, where all branches in "tree"
472
- * are assumed to have been mapped to thread identifiers.
473
- * "nThread" is the number of thread identifiers.
474
- * The result lives in a space of the form block[x, ...].
475
- */
476
- isl::multi_union_pw_aff extractDomainToThread (
477
- const detail::ScheduleTree* tree,
478
- size_t nThread) {
479
- using namespace polyhedral ::detail;
480
-
// Seed the result: an empty union set as domain and a zero multi_val in
// a set space named kBlock with nThread dimensions.
481
- auto space = isl::space (tree->ctx_ , 0 );
482
- auto empty = isl::union_set::empty (space);
483
- auto id = isl::id (tree->ctx_ , kBlock );
484
- space = space.named_set_from_params_id (id, nThread);
485
- auto zero = isl::multi_val::zero (space);
486
- auto domainToThread = isl::multi_union_pw_aff (empty, zero);
487
-
// For every mapping filter collected from "tree", gather the entries
// stored for thread identifiers 0 .. nThread-1 into one
// multi_union_pw_aff and fold it into the result with union_add.
488
- for (auto mapping : tree->collect (tree, ScheduleTreeType::MappingFilter)) {
489
- auto mappingNode = mapping->elemAs <ScheduleTreeElemMappingFilter>();
490
- auto list = isl::union_pw_aff_list (tree->ctx_ , nThread);
491
- for (size_t i = 0 ; i < nThread; ++i) {
492
- auto threadId = mapping::ThreadId::makeId (i);
// NOTE(review): .at() presumably fails if the filter has no entry for
// this thread identifier -- see the "all branches mapped" assumption above.
493
- auto threadMap = mappingNode->mapping .at (threadId);
494
- list = list.add (threadMap);
495
- }
496
- auto nodeToThread = isl::multi_union_pw_aff (space, list);
497
- domainToThread = domainToThread.union_add (nodeToThread);
498
- }
499
-
500
- return domainToThread;
501
- }
502
-
503
471
/*
504
472
* Construct a mapping
505
473
*
@@ -534,6 +502,26 @@ isl::multi_aff constructThreadToWarp(
534
502
}
535
503
} // namespace
536
504
505
/*
 * Collect one thread mapping identifier per entry of numThreads.view
 * and pass them, together with the schedule root of the scop, "tree"
 * and a tuple identifier named kBlock, to extractDomainToIds.
 * The result of extractDomainToIds is returned unchanged.
 * NOTE(review): extractDomainToIds is not visible in this excerpt;
 * presumably it yields the mapping from the domain elements active at
 * "tree" to the thread identifiers, in a space named after kBlock --
 * confirm against its definition.
 */
+ isl::multi_union_pw_aff MappedScop::threadMappingSchedule (
506
+ const detail::ScheduleTree* tree) const {
507
+ std::vector<mapping::MappingId> ids;
508
// ThreadId number i for each i below numThreads.view.size().
+ for (size_t i = 0 ; i < numThreads.view .size (); ++i) {
509
+ ids.emplace_back (mapping::ThreadId::makeId (i));
510
+ }
511
// The tuple identifier is created in the isl context of "tree".
+ auto tupleId = isl::id (tree->ctx_ , kBlock );
512
+ return extractDomainToIds (scop_->scheduleRoot (), tree, ids, tupleId);
513
+ }
514
+
515
/*
 * Collect one block mapping identifier per entry of numBlocks.view
 * and pass them, together with the schedule root of the scop, "tree"
 * and a tuple identifier named kGrid, to extractDomainToIds.
 * The result of extractDomainToIds is returned unchanged.
 * NOTE(review): extractDomainToIds is not visible in this excerpt;
 * presumably it yields the mapping from the domain elements active at
 * "tree" to the block identifiers, in a space named after kGrid --
 * confirm against its definition.
 */
+ isl::multi_union_pw_aff MappedScop::blockMappingSchedule (
516
+ const detail::ScheduleTree* tree) const {
517
+ std::vector<mapping::MappingId> ids;
518
// BlockId number i for each i below numBlocks.view.size().
+ for (size_t i = 0 ; i < numBlocks.view .size (); ++i) {
519
+ ids.emplace_back (mapping::BlockId::makeId (i));
520
+ }
521
// The tuple identifier is created in the isl context of "tree".
+ auto tupleId = isl::id (tree->ctx_ , kGrid );
522
+ return extractDomainToIds (scop_->scheduleRoot (), tree, ids, tupleId);
523
+ }
524
+
537
525
Scop::SyncLevel MappedScop::findBestSync (
538
526
detail::ScheduleTree* st1,
539
527
detail::ScheduleTree* st2,
@@ -724,7 +712,7 @@ void MappedScop::insertBestSyncInSeq(detail::ScheduleTree* seq) {
724
712
725
713
auto outer = hasOuterSequentialMember (scop_->scheduleRoot (), seq);
726
714
727
- auto domainToThread = extractDomainToThread (seq, numThreads. view . size () );
715
+ auto domainToThread = threadMappingSchedule (seq);
728
716
auto threadToWarp = constructThreadToWarp (seq->ctx_ , 32 , numThreads);
729
717
auto domainToWarp = domainToThread.apply (threadToWarp);
730
718
@@ -1080,7 +1068,7 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
1080
1068
1081
1069
// 9. Promote to registers below the loops mapped to threads.
1082
1070
if (cudaOptions.proto ().use_private_memory ()) {
1083
- promoteToRegistersBelowThreads (mappedScop-> scop () , -1ull );
1071
+ promoteToRegistersBelowThreads (* mappedScop, -1ull );
1084
1072
}
1085
1073
1086
1074
LOG_IF (INFO, FLAGS_debug_tc_mapper)
0 commit comments