@@ -106,35 +106,51 @@ template <>
106
106
const CudaDim& mappingSize<mapping::ThreadId>(const MappedScop* mscop) {
107
107
return mscop->numThreads ;
108
108
}
109
+ } // namespace
110
+
111
+ template <typename MappingTypeId>
112
+ void MappedScop::mapRemaining (detail::ScheduleTree* tree, size_t nMapped) {
113
+ size_t nToMap = mappingSize<MappingTypeId>(this ).view .size ();
114
+ if (nMapped >= nToMap) {
115
+ return ;
116
+ }
109
117
110
- // Map the affine functions in "list" to the _new_ parameters
111
- // identified by ids[i] and limited by 0 <= ids[i] < extent[i]. The
112
- // parameters must not be present in the space of partial schedules in "list"
113
- // and extents must be non-zero. The mapping corresponds to inserting a filter
118
+ std::unordered_set<MappingTypeId, typename MappingTypeId::Hash> ids;
119
+ for (size_t i = nMapped; i < nToMap; ++i) {
120
+ ids.insert (MappingTypeId::makeId (i));
121
+ }
122
+ auto root = scop_->scheduleRoot ();
123
+ auto domain = activeDomainPoints (root, tree);
124
+ auto filter = makeFixRemainingZeroFilter (domain, ids);
125
+ auto mapping = detail::ScheduleTree::makeMappingFilter (filter, ids);
126
+ insertNodeAbove (root, tree, std::move (mapping));
127
+ }
128
+
129
+ // Map the elements in "list" to successive blocks or thread identifiers,
130
+ // with the first element mapped to identifier X. The extents are obtained
131
+ // from the initial elements of numBlocks or numThreads. The identifiers
132
+ // must not be present in the space of the partial schedules in "list" and
133
+ // extents must be non-zero. The mapping corresponds to inserting a filter
114
134
// node with condition 'list % extent = ids'.
115
- // The number of elements in "list" and "ids" needs to be the same,
116
- // but "extent" is allowed to have extra elements, in which case
117
- // only the initial elements are used.
118
135
// The mapping is inserted above "tree".
119
136
//
120
- // Returns a pointer to the updated node (below the inserted filter)
137
+ // Return a pointer to the updated node (below the inserted filter)
121
138
// for call chaining purposes.
122
- template <typename MappingIdType>
123
- detail::ScheduleTree* mapToParametersWithExtent (
124
- detail::ScheduleTree* root,
139
+ template <typename MappingTypeId>
140
+ detail::ScheduleTree* MappedScop::map (
125
141
detail::ScheduleTree* tree,
126
- isl::union_pw_aff_list list,
127
- const std::vector<MappingIdType>& ids,
128
- const CudaDimView& extent) {
129
- CHECK_EQ (ids.size (), list.n ());
130
- CHECK_LE (list.n (), extent.size ()) << " dimension overflow" ;
142
+ isl::union_pw_aff_list list) {
143
+ size_t nToMap = list.n ();
144
+ const auto & extent = mappingSize<MappingTypeId>(this ).view ;
145
+ CHECK_LE (nToMap, extent.size ()) << " dimension overflow" ;
131
146
147
+ auto root = scop_->scheduleRoot ();
132
148
auto domain = activeDomainPoints (root, tree).universe ();
133
149
auto filter = domain;
134
150
135
- std::unordered_set<MappingIdType , typename MappingIdType ::Hash> idSet;
136
- for (size_t i = 0 ; i < ids. size () ; ++i) {
137
- auto id = ids[i] ;
151
+ std::unordered_set<MappingTypeId , typename MappingTypeId ::Hash> idSet;
152
+ for (size_t i = 0 ; i < nToMap ; ++i) {
153
+ auto id = MappingTypeId::makeId (i) ;
138
154
auto upa = list.get (i);
139
155
// Introduce the "mapping" parameter after checking it is not already
140
156
// present in the schedule space.
@@ -151,26 +167,10 @@ detail::ScheduleTree* mapToParametersWithExtent(
151
167
}
152
168
153
169
auto mapping = detail::ScheduleTree::makeMappingFilter (filter, idSet);
154
- return insertNodeAbove (root, tree, std::move (mapping))->child ({0 });
155
- }
156
- } // namespace
157
-
158
- template <typename MappingTypeId>
159
- void MappedScop::mapRemaining (detail::ScheduleTree* tree, size_t nMapped) {
160
- size_t nToMap = mappingSize<MappingTypeId>(this ).view .size ();
161
- if (nMapped >= nToMap) {
162
- return ;
163
- }
170
+ tree = insertNodeAbove (root, tree, std::move (mapping))->child ({0 });
164
171
165
- std::unordered_set<MappingTypeId, typename MappingTypeId::Hash> ids;
166
- for (size_t i = nMapped; i < nToMap; ++i) {
167
- ids.insert (MappingTypeId::makeId (i));
168
- }
169
- auto root = scop_->scheduleRoot ();
170
- auto domain = activeDomainPoints (root, tree);
171
- auto filter = makeFixRemainingZeroFilter (domain, ids);
172
- auto mapping = detail::ScheduleTree::makeMappingFilter (filter, ids);
173
- insertNodeAbove (root, tree, std::move (mapping));
172
+ mapRemaining<MappingTypeId>(tree, nToMap);
173
+ return tree;
174
174
}
175
175
176
176
detail::ScheduleTree* MappedScop::mapBlocksForward (
@@ -179,16 +179,9 @@ detail::ScheduleTree* MappedScop::mapBlocksForward(
179
179
auto bandNode = band->elemAs <detail::ScheduleTreeElemBand>();
180
180
CHECK (bandNode) << " expected a band, got " << *band;
181
181
182
- auto root = scop_->scheduleRoot ();
183
- std::vector<mapping::BlockId> mapped;
184
- for (size_t i = 0 ; i < nToMap; ++i) {
185
- mapped.emplace_back (mapping::BlockId::makeId (i));
186
- }
187
182
auto list = bandNode->mupa_ .get_union_pw_aff_list ();
188
183
list = list.drop (nToMap, list.n () - nToMap);
189
- band = mapToParametersWithExtent (root, band, list, mapped, numBlocks.view );
190
- mapRemaining<mapping::BlockId>(band, nToMap);
191
- return band;
184
+ return map<mapping::BlockId>(band, list);
192
185
}
193
186
194
187
// Uses as many blockSizes elements as outer coincident dimensions in the
@@ -375,16 +368,9 @@ detail::ScheduleTree* MappedScop::mapThreadsBackward(
375
368
auto ctx = band->ctx_ ;
376
369
insertNodeBelow (band, detail::ScheduleTree::makeThreadSpecificMarker (ctx));
377
370
378
- auto root = scop_->scheduleRoot ();
379
- std::vector<mapping::ThreadId> mapped;
380
- for (size_t i = 0 ; i < nToMap; ++i) {
381
- mapped.emplace_back (mapping::ThreadId::makeId (i));
382
- }
383
371
auto list = bandNode->mupa_ .get_union_pw_aff_list ().reverse ();
384
372
list = list.drop (nToMap, list.n () - nToMap);
385
- band = mapToParametersWithExtent (root, band, list, mapped, numThreads.view );
386
- mapRemaining<mapping::ThreadId>(band, nToMap);
387
- return band;
373
+ return map<mapping::ThreadId>(band, list);
388
374
}
389
375
390
376
size_t MappedScop::mapToThreads (detail::ScheduleTree* band) {
0 commit comments