@@ -391,34 +391,6 @@ TensorReferenceGroup::referenceIds() const {
391
391
}
392
392
393
393
namespace {
394
- bool hasCopyExtensionSingleChild (const ScheduleTree* tree) {
395
- if (tree->numChildren () != 1 ) {
396
- return false ;
397
- }
398
-
399
- auto extensionNode =
400
- tree->child ({0 })->elemAs <detail::ScheduleTreeElemExtension>();
401
- if (!extensionNode) {
402
- return false ;
403
- }
404
-
405
- if ((tree->child ({0 })->numChildren () != 1 ) &&
406
- (tree->child ({0 , 0 })->elemAs <detail::ScheduleTreeElemSequence>())) {
407
- return false ;
408
- }
409
-
410
- for (auto e : isl::UnionAsVector<isl::union_map>(extensionNode->extension_ )) {
411
- if (!e.has_tuple_name (isl::dim_type::out)) {
412
- return false ;
413
- }
414
- if (e.get_tuple_name (isl::dim_type::out) != kReadIdName &&
415
- e.get_tuple_name (isl::dim_type::out) != kWriteIdName ) {
416
- return false ;
417
- }
418
- }
419
- return true ;
420
- }
421
-
422
394
// Construct the set containing all tensor elements.
423
395
//
424
396
// Find the Halide image corresponding to the given tensorId. Transform its
@@ -524,48 +496,26 @@ ScheduleTree* insertCopiesUnder(
524
496
bool reads = !group.scopedReads ().is_empty ();
525
497
bool writes = !group.scopedWrites ().is_empty ();
526
498
527
- if (hasCopyExtensionSingleChild (tree)) {
528
- auto extensionNode = tree->child ({0 });
529
- auto sequenceNode = tree->child ({0 , 0 });
530
-
531
- auto & ext =
532
- extensionNode->elemAs <detail::ScheduleTreeElemExtension>()->extension_ ;
533
- if (reads) {
534
- ext = ext.unite (isl::union_map (readExtension));
535
- sequenceNode->insertChild (0 , std::move (readFilterNode));
536
- }
537
- if (writes) {
538
- ext = ext.unite (isl::union_map (writeExtension));
539
- sequenceNode->appendChild (std::move (writeFilterNode));
540
- }
541
- return tree;
499
+ if (tree->numChildren () == 0 ) {
500
+ // The point underneath a leaf node cannot be referenced,
501
+ // so insert a dummy sequence first. It will be extended
502
+ // with the reads and/or writes.
503
+ insertSequenceBelow (root, tree);
542
504
}
543
505
544
- auto mainCompFilter = activeDomainPoints (root, tree).universe ();
545
- auto mainCompFilterNode =
546
- ScheduleTree::makeFilter (mainCompFilter, tree->detachChildren ());
547
-
548
- // XXX: I don't really like the syntax-imposed impossibility to create a
549
- // sequence node with no children.
550
- auto sequenceNode = ScheduleTree::makeSequence (
551
- reads ? std::move (readFilterNode) : std::move (mainCompFilterNode));
552
506
if (reads) {
553
- sequenceNode->appendChild (std::move (mainCompFilterNode));
507
+ insertExtensionBefore (
508
+ root, tree, tree->child ({0 }), readExtension, std::move (readFilterNode));
554
509
}
555
510
if (writes) {
556
- sequenceNode->appendChild (std::move (writeFilterNode));
511
+ insertExtensionAfter (
512
+ root,
513
+ tree,
514
+ tree->child ({0 }),
515
+ writeExtension,
516
+ std::move (writeFilterNode));
557
517
}
558
518
559
- auto extensionUmap = isl::union_map::empty (promotionSpace.params ());
560
- if (reads) {
561
- extensionUmap = extensionUmap.unite (readExtension);
562
- }
563
- if (writes) {
564
- extensionUmap = extensionUmap.unite (writeExtension);
565
- }
566
- auto extensionNode =
567
- ScheduleTree::makeExtension (extensionUmap, std::move (sequenceNode));
568
- tree->appendChild (std::move (extensionNode));
569
519
return tree;
570
520
}
571
521
} // namespace polyhedral
0 commit comments