@@ -431,119 +431,84 @@ bool hasOuterSequentialMember(
431
431
return false ;
432
432
}
433
433
434
- // Intersect the union set with all the mapping
435
- // filters params in the given schedule tree
436
- isl::union_set intersectMappingFilterParams (
437
- detail::ScheduleTree* st,
438
- isl::union_set us) {
439
- if (auto filter = st->elemAsBase <detail::ScheduleTreeElemFilter>()) {
440
- us = us.intersect (filter->filter_ );
441
- }
434
+ // Name of the space of threads inside a block
435
+ constexpr auto kBlock = " block" ;
436
+ // Name of the space of warps
437
+ constexpr auto kWarp = " warp" ;
442
438
443
- auto children = st->children ();
444
- auto nChildren = children.size ();
445
- if (nChildren == 1 ) {
446
- us = intersectMappingFilterParams (children[0 ], us);
447
- } else if (nChildren > 1 ) {
448
- auto usParent = us;
449
- us = intersectMappingFilterParams (children[0 ], us);
450
- for (size_t i = 1 ; i < nChildren; ++i) {
451
- us = us.unite (intersectMappingFilterParams (children[i], usParent));
452
- }
453
- }
454
-
455
- return us;
456
- }
439
+ /*
440
+ * Extract a mapping from the domain elements active at "tree"
441
+ * to the thread identifiers, where all branches in "tree"
442
+ * are assumed to have been mapped to thread identifiers.
443
+ * "nThread" is the number of thread identifiers.
444
+ * The result lives in a space of the form block[x, ...].
445
+ */
446
+ isl::multi_union_pw_aff extractDomainToThread (
447
+ const detail::ScheduleTree* tree,
448
+ size_t nThread) {
449
+ using namespace polyhedral ::detail;
457
450
458
- // Change the name of the isl ids tied to threads and blocks
459
- // by adding a suffix
460
- isl::union_set modifyMappingNames (
461
- isl::union_set set,
462
- const std::string suffix) {
463
- USING_MAPPING_SHORT_NAMES (BX, BY, BZ, TX, TY, TZ);
464
- std::unordered_set<isl::id, isl::IslIdIslHash> identifiers{
465
- BX, BY, BZ, TX, TY, TZ};
466
-
467
- auto space = set.get_space ();
468
- for (auto id : identifiers) {
469
- auto name = id.get_name ();
470
- auto dim = space.find_dim_by_name (isl::dim_type::param, id.get_name ());
471
- CHECK_LE (0 , dim);
472
- space = space.set_dim_name (isl::dim_type::param, dim, name + suffix);
473
- }
474
- auto newSet = isl::union_set::empty (space);
475
- set.foreach_set ([&newSet, &identifiers, &suffix](isl::set setInFun) {
476
- for (auto id : identifiers) {
477
- auto name = id.get_name ();
478
- auto dim =
479
- setInFun.get_space ().find_dim_by_name (isl::dim_type::param, name);
480
- CHECK_LE (0 , dim);
481
- setInFun =
482
- setInFun.set_dim_name (isl::dim_type::param, dim, name + suffix);
451
+ auto space = isl::space (tree->ctx_ , 0 );
452
+ auto empty = isl::union_set::empty (space);
453
+ auto id = isl::id (tree->ctx_ , kBlock );
454
+ space = space.named_set_from_params_id (id, nThread);
455
+ auto zero = isl::multi_val::zero (space);
456
+ auto domainToThread = isl::multi_union_pw_aff (empty, zero);
457
+
458
+ for (auto mapping : tree->collect (tree, ScheduleTreeType::MappingFilter)) {
459
+ auto mappingNode = mapping->elemAs <ScheduleTreeElemMappingFilter>();
460
+ auto list = isl::union_pw_aff_list (tree->ctx_ , nThread);
461
+ for (size_t i = 0 ; i < nThread; ++i) {
462
+ auto threadId = mapping::ThreadId::makeId (i);
463
+ auto threadMap = mappingNode->mapping .at (threadId);
464
+ list = list.add (threadMap);
483
465
}
484
- newSet = newSet.unite (setInFun);
485
- });
486
- return newSet;
487
- }
488
-
489
- // Get the formula computing the linearized index of a thread in a block.
490
- isl::aff getLinearizedThreadIdxFormula (
491
- isl::space space,
492
- const Block& block,
493
- const std::string& suffix = " " ) {
494
- USING_MAPPING_SHORT_NAMES (BX, BY, BZ, TX, TY, TZ);
495
- std::vector<std::pair<isl::id, unsigned >> mappingIds{
496
- {TX, TX.mappingSize (block)},
497
- {TY, TY.mappingSize (block)},
498
- {TZ, TZ.mappingSize (block)}};
499
-
500
- isl::aff formula = isl::aff (isl::local_space (space));
501
-
502
- for (int i = (int )mappingIds.size () - 1 ; i >= 0 ; --i) {
503
- auto name = mappingIds[i].first .to_str ();
504
- auto dim = space.find_dim_by_name (isl::dim_type::param, name + suffix);
505
- CHECK_LE (0 , dim);
506
- auto id = space.get_dim_id (isl::dim_type::param, dim);
507
- isl::aff aff (isl::aff::param_on_domain_space (space, id));
508
- formula = formula * mappingIds[i].second + aff;
466
+ auto nodeToThread = isl::multi_union_pw_aff (space, list);
467
+ domainToThread = domainToThread.union_add (nodeToThread);
509
468
}
510
469
511
- return formula ;
470
+ return domainToThread ;
512
471
}
513
472
514
- // Return the constraints ensuring that the points with parameters
515
- // [t0,t1,t2] and [t0',t1',t2'] are in the same warp.
516
- // (where t0 is "t0" + suffix1 and t0' is "t0" + suffix2)
517
- // if suffix1 is "_1" and suffix2 is "_2", the constraint is in the form
518
- // ((t0_1 + a * t1_1 + b * t2_1) / warpSize).floor()
519
- // == ((t0_2 + a' * t1_2 + b' * t1_2) / warpSize).floor()
520
- // with t0_1 + a * t1_1 + b * t2_1 the linearized formula of the thread index.
521
- // This function returns a set because it might change in the future,
522
- // and take into account the blocks.
523
- isl::set getSameWarpConstraints (
524
- isl::space space,
525
- const std::string& suffix1,
526
- const std::string& suffix2 ,
473
+ /*
474
+ * Construct a mapping
475
+ *
476
+ * block[x] -> warp[floor((x)/warpSize)]
477
+ * block[x, y] -> warp[floor((x + s_x * (y))/ warpSize)]
478
+ * block[x, y, z] -> warp[floor((x + s_x * (y + s_y * (z)))/ warpSize)]
479
+ *
480
+ * uniquely mapping thread identifiers that belong to the same warp
481
+ * (of size "warpSize") to a warp identifier,
482
+ * based on the thread sizes s_x, s_y up to s_z in "block".
483
+ */
484
+ isl::multi_aff constructThreadToWarp (
485
+ isl::ctx ctx ,
527
486
const unsigned warpSize,
528
487
const Block& block) {
529
- USING_MAPPING_SHORT_NAMES (BX, BY, BZ, TX, TY, TZ);
530
- std::vector<std::pair<isl::id, unsigned >> mappingIds{
531
- {TX, TX.mappingSize (block)},
532
- {TY, TY.mappingSize (block)},
533
- {TZ, TZ.mappingSize (block)}};
534
-
535
- auto formula1 = getLinearizedThreadIdxFormula (space, block, suffix1);
536
- auto formula2 = getLinearizedThreadIdxFormula (space, block, suffix2);
488
+ auto space = isl::space (ctx, 0 );
489
+ auto id = isl::id (ctx, kBlock );
490
+ auto blockSpace = space.named_set_from_params_id (id, block.view .size ());
491
+ auto warpSpace = space.named_set_from_params_id (isl::id (ctx, kWarp ), 1 );
492
+ auto aff = isl::aff::zero_on_domain (blockSpace);
493
+
494
+ auto nThread = block.view .size ();
495
+ auto identity = isl::multi_aff::identity (blockSpace.map_from_set ());
496
+ for (int i = nThread - 1 ; i >= 0 ; --i) {
497
+ aff = aff.scale (isl::val (ctx, block.view [i]));
498
+ aff = aff.add (identity.get_aff (i));
499
+ }
537
500
538
- return (
539
- isl::aff_set ((formula1 / warpSize). floor ()) ==
540
- (formula2 / warpSize). floor ( ));
501
+ aff = aff. scale_down ( isl::val (ctx, warpSize)). floor ();
502
+ auto mapSpace = blockSpace. product (warpSpace). unwrap ();
503
+ return isl::multi_aff (mapSpace, isl::aff_list (aff ));
541
504
}
542
505
} // namespace
543
506
544
507
Scop::SyncLevel MappedScop::findBestSync (
545
508
detail::ScheduleTree* st1,
546
- detail::ScheduleTree* st2) {
509
+ detail::ScheduleTree* st2,
510
+ isl::multi_union_pw_aff domainToThread,
511
+ isl::multi_union_pw_aff domainToWarp) {
547
512
// Active points in the two schedule trees
548
513
auto stRoot = scop_->scheduleRoot ();
549
514
auto activePoints1 = activeDomainPointsBelow (stRoot, st1);
@@ -557,41 +522,16 @@ Scop::SyncLevel MappedScop::findBestSync(
557
522
return Scop::SyncLevel::None;
558
523
}
559
524
560
- // The domain and the context of the root schedule tree
561
- auto domainAndContext = scop_->domain ();
562
525
CHECK_LE (1u , scop_->scheduleRoot ()->children ().size ());
563
526
auto contextSt = scop_->scheduleRoot ()->children ()[0 ];
564
527
auto contextElem = contextSt->elemAs <detail::ScheduleTreeElemContext>();
565
528
CHECK (nullptr != contextElem);
566
- domainAndContext = domainAndContext .intersect_params (contextElem->context_ );
529
+ dependences = dependences .intersect_params (contextElem->context_ );
567
530
568
- // The domain of both schedule trees filtered by mapping filters,
569
- // and then modified to have different threads and blocks names.
570
- auto domain1 = intersectMappingFilterParams (st1, domainAndContext);
571
- auto domain2 = intersectMappingFilterParams (st2, domainAndContext);
572
- auto suffix1 = " _1" ;
573
- auto suffix2 = " _2" ;
574
- domain1 = modifyMappingNames (domain1, suffix1);
575
- domain2 = modifyMappingNames (domain2, suffix2);
576
-
577
- // The dependences between the two schedule trees
578
- // with mapping from threads and blocks
579
- auto mappedDependences =
580
- isl::union_map::from_domain_and_range (domain1, domain2);
581
- mappedDependences = mappedDependences.intersect (dependences);
582
-
583
- auto space = mappedDependences.get_space ();
584
- auto sameThreadConstraint =
585
- getSameWarpConstraints (space, suffix1, suffix2, 1 , numThreads);
586
- auto sameWarpConstraints =
587
- getSameWarpConstraints (space, suffix1, suffix2, 32 , numThreads);
588
-
589
- if (mappedDependences ==
590
- mappedDependences.intersect_params (sameThreadConstraint)) {
531
+ if (dependences.is_subset (dependences.eq_at (domainToThread))) {
591
532
return Scop::SyncLevel::None;
592
- } else if (
593
- mappedDependences ==
594
- mappedDependences.intersect_params (sameWarpConstraints)) {
533
+ }
534
+ if (dependences.is_subset (dependences.eq_at (domainToWarp))) {
595
535
return Scop::SyncLevel::Warp;
596
536
}
597
537
return Scop::SyncLevel::Block;
@@ -754,6 +694,10 @@ void MappedScop::insertBestSyncInSeq(detail::ScheduleTree* seq) {
754
694
755
695
auto outer = hasOuterSequentialMember (scop_->scheduleRoot (), seq);
756
696
697
+ auto domainToThread = extractDomainToThread (seq, numThreads.view .size ());
698
+ auto threadToWarp = constructThreadToWarp (seq->ctx_ , 32 , numThreads);
699
+ auto domainToWarp = domainToThread.apply (threadToWarp);
700
+
757
701
std::vector<std::vector<int >> bestSync (
758
702
nChildren, std::vector<int >(nChildren + 1 ));
759
703
// Get the synchronization needed between children[i] and children[i+k]
@@ -765,7 +709,8 @@ void MappedScop::insertBestSyncInSeq(detail::ScheduleTree* seq) {
765
709
for (size_t i = 0 ; i < nChildren; ++i) {
766
710
for (size_t k = 0 ; k < nChildren; ++k) {
767
711
auto ik = (i + k) % nChildren;
768
- bestSync[i][k] = (int )findBestSync (children[i], children[ik]);
712
+ bestSync[i][k] = (int )findBestSync (
713
+ children[i], children[ik], domainToThread, domainToWarp);
769
714
}
770
715
}
771
716
0 commit comments