@@ -107,42 +107,50 @@ const CudaDim& mappingSize<mapping::ThreadId>(const MappedScop* mscop) {
107
107
return mscop->numThreads ;
108
108
}
109
109
110
- // Map "pos"-th schedule dimension of the band node identified by "tree" to a
111
- // _new_ parameter identified by "id" and limited by 0 <= id < extent. The
112
- // parameter must not be present in the space of partial schedule of "tree" and
113
- // extent must be non-zero. The mapping corresponds to inserting a filter
114
- // node with condition 'dim % extent = id' where dim is "pos"-th
115
- // schedule dimension.
110
+ // Map the affine functions in "list" to the _new_ parameters
111
+ // identified by ids[i] and limited by 0 <= ids[i] < extent[i]. The
112
+ // parameters must not be present in the space of partial schedules in "list"
113
+ // and extents must be non-zero. The mapping corresponds to inserting a filter
114
+ // node with condition 'list % extent = ids'.
115
+ // The number of elements in "list" and "ids" needs to be the same,
116
+ // but "extent" is allowed to have extra elements, in which case
117
+ // only the initial elements are used.
118
+ // The mapping is inserted above "tree".
116
119
//
117
- // Returns a pointer to the updated band (below the inserted filter)
120
+ // Returns a pointer to the updated node (below the inserted filter)
118
121
// for call chaining purposes.
119
122
template <typename MappingIdType>
120
- detail::ScheduleTree* mapToParameterWithExtent (
123
+ detail::ScheduleTree* mapToParametersWithExtent (
121
124
detail::ScheduleTree* root,
122
125
detail::ScheduleTree* tree,
123
- size_t pos,
124
- MappingIdType id,
125
- size_t extent) {
126
- auto band = tree->elemAs <detail::ScheduleTreeElemBand>();
127
- CHECK (band) << " expected a band, got " << *tree;
128
- CHECK_GE (pos, 0u ) << " dimension underflow" ;
129
- CHECK_LT (pos, band->nMember ()) << " dimension overflow" ;
130
- CHECK_NE (extent, 0u ) << " NYI: mapping to 0" ;
126
+ isl::union_pw_aff_list list,
127
+ const std::vector<MappingIdType>& ids,
128
+ const CudaDimView& extent) {
129
+ CHECK_EQ (ids.size (), list.n ());
130
+ CHECK_LE (list.n (), extent.size ()) << " dimension overflow" ;
131
131
132
132
auto domain = activeDomainPoints (root, tree).universe ();
133
+ auto filter = domain;
134
+
135
+ std::unordered_set<MappingIdType, typename MappingIdType::Hash> idSet;
136
+ for (size_t i = 0 ; i < ids.size (); ++i) {
137
+ auto id = ids[i];
138
+ auto upa = list.get (i);
139
+ // Introduce the "mapping" parameter after checking it is not already
140
+ // present in the schedule space.
141
+ CHECK (not upa.involves_param (id));
142
+ CHECK_NE (extent[i], 0u ) << " NYI: mapping to 0" ;
133
143
134
- // Introduce the "mapping" parameter after checking it is not already present
135
- // in the schedule space.
136
- CHECK (not band->mupa_ .involves_param (id));
137
-
138
- // Create mapping filter by equating the newly introduced
139
- // parameter "id" to the "pos"-th schedule dimension modulo its extent.
140
- auto upa =
141
- band->mupa_ .get_union_pw_aff (pos).mod_val (isl::val (tree->ctx_ , extent));
142
- upa = upa.sub (isl::union_pw_aff::param_on_domain (domain, id));
143
- auto filter = upa.zero_union_set ();
144
- auto mapping =
145
- detail::ScheduleTree::makeMappingFilter<MappingIdType>(filter, {id});
144
+ // Create mapping filter by equating the newly introduced
145
+ // parameter ids[i] to the "i"-th affine function modulo its extent.
146
+ upa = upa.mod_val (isl::val (tree->ctx_ , extent[i]));
147
+ upa = upa.sub (isl::union_pw_aff::param_on_domain (domain, id));
148
+ filter = filter.intersect (upa.zero_union_set ());
149
+
150
+ idSet.emplace (id);
151
+ }
152
+
153
+ auto mapping = detail::ScheduleTree::makeMappingFilter (filter, idSet);
146
154
return insertNodeAbove (root, tree, std::move (mapping))->child ({0 });
147
155
}
148
156
} // namespace
@@ -168,11 +176,17 @@ void MappedScop::mapRemaining(detail::ScheduleTree* tree, size_t nMapped) {
168
176
detail::ScheduleTree* MappedScop::mapBlocksForward (
169
177
detail::ScheduleTree* band,
170
178
size_t nToMap) {
179
+ auto bandNode = band->elemAs <detail::ScheduleTreeElemBand>();
180
+ CHECK (bandNode) << " expected a band, got " << *band;
181
+
171
182
auto root = scop_->scheduleRoot ();
183
+ std::vector<mapping::BlockId> mapped;
172
184
for (size_t i = 0 ; i < nToMap; ++i) {
173
- auto id = mapping::BlockId::makeId (i);
174
- band = mapToParameterWithExtent (root, band, i, id, numBlocks.view [i]);
185
+ mapped.emplace_back (mapping::BlockId::makeId (i));
175
186
}
187
+ auto list = bandNode->mupa_ .get_union_pw_aff_list ();
188
+ list = list.drop (nToMap, list.n () - nToMap);
189
+ band = mapToParametersWithExtent (root, band, list, mapped, numBlocks.view );
176
190
mapRemaining<mapping::BlockId>(band, nToMap);
177
191
return band;
178
192
}
@@ -362,11 +376,13 @@ detail::ScheduleTree* MappedScop::mapThreadsBackward(
362
376
insertNodeBelow (band, detail::ScheduleTree::makeThreadSpecificMarker (ctx));
363
377
364
378
auto root = scop_->scheduleRoot ();
379
+ std::vector<mapping::ThreadId> mapped;
365
380
for (size_t i = 0 ; i < nToMap; ++i) {
366
- auto id = mapping::ThreadId::makeId (i);
367
- auto pos = nMember - 1 - i;
368
- band = mapToParameterWithExtent (root, band, pos, id, numThreads.view [i]);
381
+ mapped.emplace_back (mapping::ThreadId::makeId (i));
369
382
}
383
+ auto list = bandNode->mupa_ .get_union_pw_aff_list ().reverse ();
384
+ list = list.drop (nToMap, list.n () - nToMap);
385
+ band = mapToParametersWithExtent (root, band, list, mapped, numThreads.view );
370
386
mapRemaining<mapping::ThreadId>(band, nToMap);
371
387
return band;
372
388
}
0 commit comments