34
34
#include " tc/core/polyhedral/mapping_types.h"
35
35
#include " tc/core/polyhedral/schedule_tree_elem.h"
36
36
#include " tc/core/polyhedral/schedule_tree_matcher.h"
37
+ #include " tc/core/polyhedral/schedule_utils.h"
37
38
#include " tc/core/scope_guard.h"
38
39
#include " tc/external/isl.h"
39
40
@@ -47,164 +48,6 @@ namespace tc {
47
48
namespace polyhedral {
48
49
using namespace detail ;
49
50
50
- isl::union_map extendSchedule (
51
- const ScheduleTree* node,
52
- isl::union_map schedule) {
53
- if (auto bandElem = node->elemAs <ScheduleTreeElemBand>()) {
54
- if (bandElem->nMember () > 0 ) {
55
- schedule =
56
- schedule.flat_range_product (isl::union_map::from (bandElem->mupa_ ));
57
- }
58
- } else if (auto filterElem = node->elemAs <ScheduleTreeElemFilter>()) {
59
- schedule = schedule.intersect_domain (filterElem->filter_ );
60
- } else if (auto extensionElem = node->elemAs <ScheduleTreeElemExtension>()) {
61
- // FIXME: we may need to restrict the range of reversed extension map to
62
- // schedule values that correspond to active domain elements at this
63
- // point.
64
- schedule = schedule.unite (
65
- extensionElem->extension_ .reverse ().intersect_range (schedule.range ()));
66
- }
67
-
68
- return schedule;
69
- }
70
-
71
- namespace {
72
- isl::union_map partialScheduleImpl (
73
- const ScheduleTree* root,
74
- const ScheduleTree* node,
75
- bool useNode) {
76
- auto nodes = node->ancestors (root);
77
- if (useNode) {
78
- nodes.push_back (node);
79
- }
80
- TC_CHECK_GT (nodes.size (), 0u ) << " root node does not have a prefix schedule" ;
81
- auto domain = root->elemAs <ScheduleTreeElemDomain>();
82
- TC_CHECK (domain);
83
- auto schedule = isl::union_map::from_domain (domain->domain_ );
84
- for (auto anc : nodes) {
85
- if (anc->elemAs <ScheduleTreeElemDomain>()) {
86
- TC_CHECK (anc == root);
87
- } else {
88
- schedule = extendSchedule (anc, schedule);
89
- }
90
- }
91
- return schedule;
92
- }
93
- } // namespace
94
-
95
- isl::union_map prefixSchedule (
96
- const ScheduleTree* root,
97
- const ScheduleTree* node) {
98
- return partialScheduleImpl (root, node, false );
99
- }
100
-
101
- isl::union_map partialSchedule (
102
- const ScheduleTree* root,
103
- const ScheduleTree* node) {
104
- return partialScheduleImpl (root, node, true );
105
- }
106
-
107
- namespace {
108
- /*
109
- * If "node" is any filter, then intersect "domain" with that filter.
110
- */
111
- isl::union_set applyFilter (isl::union_set domain, const ScheduleTree* node) {
112
- if (auto filterElem = node->elemAs <ScheduleTreeElemFilter>()) {
113
- return domain.intersect (filterElem->filter_ );
114
- }
115
- return domain;
116
- }
117
-
118
- /*
119
- * If "node" is a mapping, then intersect "domain" with its filter.
120
- */
121
- isl::union_set applyMapping (isl::union_set domain, const ScheduleTree* node) {
122
- if (auto filterElem = node->elemAs <ScheduleTreeElemMapping>()) {
123
- return domain.intersect (filterElem->filter_ );
124
- }
125
- return domain;
126
- }
127
-
128
- // Get the set of domain elements that are active below
129
- // the given branch of nodes, filtered using "filter".
130
- //
131
- // Domain elements are introduced by the root domain node. Some nodes
132
- // refine this set of elements based on "filter". Extension nodes
133
- // are considered to introduce additional domain points.
134
- isl::union_set collectDomain (
135
- const ScheduleTree* root,
136
- const vector<const ScheduleTree*>& nodes,
137
- isl::union_set (*filter)(isl::union_set domain, const ScheduleTree* node)) {
138
- auto domainElem = root->elemAs <ScheduleTreeElemDomain>();
139
- TC_CHECK (domainElem) << " root must be a Domain node" << *root;
140
-
141
- auto domain = domainElem->domain_ ;
142
-
143
- for (auto anc : nodes) {
144
- domain = filter (domain, anc);
145
- if (auto extensionElem = anc->elemAs <ScheduleTreeElemExtension>()) {
146
- auto parentSchedule = prefixSchedule (root, anc);
147
- auto extension = extensionElem->extension_ ;
148
- TC_CHECK (parentSchedule) << " missing root domain node" ;
149
- parentSchedule = parentSchedule.intersect_domain (domain);
150
- domain = domain.unite (parentSchedule.range ().apply (extension));
151
- }
152
- }
153
- return domain;
154
- }
155
-
156
- // Get the set of domain elements that are active below
157
- // the given branch of nodes.
158
- isl::union_set activeDomainPointsHelper (
159
- const ScheduleTree* root,
160
- const vector<const ScheduleTree*>& nodes) {
161
- return collectDomain (root, nodes, &applyFilter);
162
- }
163
-
164
- } // namespace
165
-
166
- isl::union_set prefixMappingFilter (
167
- const ScheduleTree* root,
168
- const ScheduleTree* node) {
169
- return collectDomain (root, node->ancestors (root), &applyMapping);
170
- }
171
-
172
- isl::union_set activeDomainPoints (
173
- const ScheduleTree* root,
174
- const ScheduleTree* node) {
175
- return activeDomainPointsHelper (root, node->ancestors (root));
176
- }
177
-
178
- isl::union_set activeDomainPointsBelow (
179
- const ScheduleTree* root,
180
- const ScheduleTree* node) {
181
- auto ancestors = node->ancestors (root);
182
- ancestors.emplace_back (node);
183
- return activeDomainPointsHelper (root, ancestors);
184
- }
185
-
186
- vector<ScheduleTree*> collectScheduleTreesPath (
187
- std::function<ScheduleTree*(ScheduleTree*)> next,
188
- ScheduleTree* start) {
189
- vector<ScheduleTree*> res{start};
190
- auto n = start;
191
- while ((n = next (n)) != nullptr ) {
192
- res.push_back (n);
193
- }
194
- return res;
195
- }
196
-
197
- vector<const ScheduleTree*> collectScheduleTreesPath (
198
- std::function<const ScheduleTree*(const ScheduleTree*)> next,
199
- const ScheduleTree* start) {
200
- vector<const ScheduleTree*> res{start};
201
- auto n = start;
202
- while ((n = next (n)) != nullptr ) {
203
- res.push_back (n);
204
- }
205
- return res;
206
- }
207
-
208
51
// Replace "tree" in the list of its parent's children with newTree.
209
52
// Returns the pointer to newTree for call chaining purposes.
210
53
ScheduleTree* swapSubtree (
@@ -432,85 +275,6 @@ ScheduleTree* bandScale(ScheduleTree* tree, const vector<size_t>& scales) {
432
275
return tree;
433
276
}
434
277
435
- namespace {
436
-
437
- template <typename T>
438
- vector<T> reversed (const vector<T>& vec) {
439
- vector<T> result;
440
- result.reserve (vec.size ());
441
- result.insert (result.begin (), vec.rbegin (), vec.rend ());
442
- return result;
443
- }
444
-
445
- template <typename T>
446
- vector<const ScheduleTree*> filterType (const vector<const ScheduleTree*>& vec) {
447
- vector<const ScheduleTree*> result;
448
- for (auto e : vec) {
449
- if (e->elemAs <T>()) {
450
- result.push_back (e);
451
- }
452
- }
453
- return result;
454
- }
455
-
456
- template <typename T, typename Func>
457
- T foldl (const vector<const ScheduleTree*> vec, Func op, T init = T()) {
458
- T value = init;
459
- for (auto st : vec) {
460
- value = op (st, value);
461
- }
462
- return value;
463
- }
464
-
465
- template <typename ... Args>
466
- ostream& operator <<(ostream& os, const vector<Args...>& v) {
467
- os << " [" ;
468
- bool first = true ;
469
- for (auto const & ve : v) {
470
- if (!first) {
471
- os << " , " ;
472
- }
473
- os << ve;
474
- first = true ;
475
- }
476
- os << " ]" ;
477
- return os;
478
- }
479
- } // namespace
480
-
481
- isl::multi_union_pw_aff infixScheduleMupa (
482
- const ScheduleTree* root,
483
- const ScheduleTree* relativeRoot,
484
- const ScheduleTree* tree) {
485
- auto domainElem = root->elemAs <ScheduleTreeElemDomain>();
486
- TC_CHECK (domainElem);
487
- auto domain = domainElem->domain_ .universe ();
488
- auto zero = isl::multi_val::zero (domain.get_space ().set_from_params ());
489
- auto prefix = isl::multi_union_pw_aff (domain, zero);
490
- prefix = foldl (
491
- filterType<ScheduleTreeElemBand>(tree->ancestors (relativeRoot)),
492
- [](const ScheduleTree* st, isl::multi_union_pw_aff pref) {
493
- auto mupa = st->elemAs <ScheduleTreeElemBand>()->mupa_ ;
494
- return pref.flat_range_product (mupa);
495
- },
496
- prefix);
497
- return prefix;
498
- }
499
-
500
- isl::multi_union_pw_aff prefixScheduleMupa (
501
- const ScheduleTree* root,
502
- const ScheduleTree* tree) {
503
- return infixScheduleMupa (root, root, tree);
504
- }
505
-
506
- isl::multi_union_pw_aff partialScheduleMupa (
507
- const detail::ScheduleTree* root,
508
- const detail::ScheduleTree* tree) {
509
- auto prefix = prefixScheduleMupa (root, tree);
510
- auto band = tree->elemAs <ScheduleTreeElemBand>();
511
- return band ? prefix.flat_range_product (band->mupa_ ) : prefix;
512
- }
513
-
514
278
void updateTopLevelContext (detail::ScheduleTree* root, isl::set context) {
515
279
if (!matchOne (tc::polyhedral::domain (tc::polyhedral::context (any ())), root)) {
516
280
root->appendChild (ScheduleTree::makeContext (
@@ -832,55 +596,5 @@ void orderAfter(ScheduleTree* root, ScheduleTree* tree, isl::union_set filter) {
832
596
seq->insertChild (0 , gistedFilter (other, parent->detachChild (childPos)));
833
597
parent->insertChild (childPos, std::move (seq));
834
598
}
835
-
836
- /*
837
- * Extract a mapping from the domain elements active at "tree"
838
- * to identifiers "ids", where all branches in "tree"
839
- * are assumed to have been mapped to these identifiers.
840
- * The result lives in a space of the form "tupleId"["ids"...].
841
- */
842
- isl::multi_union_pw_aff extractDomainToIds (
843
- const detail::ScheduleTree* root,
844
- const detail::ScheduleTree* tree,
845
- const std::vector<mapping::MappingId>& ids,
846
- isl::id tupleId) {
847
- using namespace polyhedral ::detail;
848
-
849
- auto space = isl::space (tree->ctx_ , 0 );
850
- auto empty = isl::union_set::empty (space);
851
- space = space.named_set_from_params_id (tupleId, ids.size ());
852
- auto zero = isl::multi_val::zero (space);
853
- auto domainToIds = isl::multi_union_pw_aff (empty, zero);
854
-
855
- for (auto mapping : tree->collect (tree, ScheduleTreeType::Mapping)) {
856
- auto mappingNode = mapping->elemAs <ScheduleTreeElemMapping>();
857
- auto list = isl::union_pw_aff_list (tree->ctx_ , ids.size ());
858
- for (auto id : ids) {
859
- if (mappingNode->mapping .count (id) == 0 ) {
860
- break ;
861
- }
862
- auto idMap = mappingNode->mapping .at (id);
863
- list = list.add (idMap);
864
- }
865
- // Ignore this node if it does not map to all required ids.
866
- if (static_cast <size_t >(list.size ()) != ids.size ()) {
867
- continue ;
868
- }
869
- auto nodeToIds = isl::multi_union_pw_aff (space, list);
870
- auto active = activeDomainPoints (root, mapping);
871
- TC_CHECK (active.intersect (domainToIds.domain ()).is_empty ())
872
- << " conflicting mappings; are the filters in the tree disjoint?" ;
873
- nodeToIds = nodeToIds.intersect_domain (active);
874
- domainToIds = domainToIds.union_add (nodeToIds);
875
- }
876
-
877
- auto active = activeDomainPoints (root, tree);
878
- TC_CHECK (active.is_subset (domainToIds.domain ()))
879
- << " not all domain points of\n "
880
- << active << " \n were mapped to the required ids" ;
881
-
882
- return domainToIds;
883
- }
884
-
885
599
} // namespace polyhedral
886
600
} // namespace tc
0 commit comments