@@ -200,37 +200,6 @@ isl::union_map fullSchedule(const detail::ScheduleTree* root) {
   return schedule;
 }

-/*
- * Insert map constraints that equate first "nDims" input dimensions to newly
- * introduced parameters.
- */
-isl::map fixOuterInputDimsAsParameters(isl::map map, unsigned nDims) {
-  if (nDims < 0 || nDims > map.dim(isl::dim_type::in)) {
-    std::stringstream ss;
-    ss << nDims << " is out of [0, " << map.dim(isl::dim_type::in)
-       << ") range";
-    throw promotion::OutOfRangeException(ss.str());
-  }
-
-  auto fixedMap = map;
-  auto localSpace = isl::local_space(map.get_space().domain());
-  auto nParams = map.dim(isl::dim_type::param);
-  localSpace = localSpace.add_dims(isl::dim_type::param, nDims);
-  for (unsigned i = 0; i < nDims; ++i) {
-    localSpace = localSpace.set_dim_name(
-        isl::dim_type::param,
-        nParams + i,
-        "__tcFixerParam" + std::to_string(i));
-  }
-  for (unsigned i = 0; i < nDims; ++i) {
-    auto left = isl::aff(localSpace, isl::dim_type::param, nParams + i);
-    auto right = isl::aff(localSpace, isl::dim_type::set, i);
-    auto dom = isl::aff_set(left) == right;
-    fixedMap = fixedMap.intersect_domain(dom);
-  }
-  return fixedMap;
-}
-
 /*
  * Check if a reference group features reuse within the "outer" schedule.
  * In particular, check that for some given point in the outer schedule and
@@ -339,19 +308,25 @@ bool promotionImprovesCoalescing(
 }

 /*
- * Check if the given "group" can be promoted to registers for the given active
- * domain points under full "schedule" where "nThreads" consecutive dimensions
- * at "depth"
- * are mapped to threads (the innermost of them being mapped to thread x).
+ * Check if the given "group" can be promoted to registers for the given
+ * mapping to thread identifiers and within the given outer schedule.
  *
  * In particular, the group's footprint must contain only one element and the
- * same tensor element should never be accessed by two different threads.
+ * same tensor element should never be accessed by two different threads
+ * within the same iteration of the outer schedule.
+ * The second test is performed by checking that there is only a single
+ * thread associated with a given pair of tensor element and outer schedule
+ * iteration.
+ * Note that the test for a single thread is performed by looking
+ * at the range of "thread". This range may be larger than the number
+ * of threads, so multiple instances may get mapped to the same thread.
+ * Requiring such instances to map to different threads is therefore slightly
+ * more conservative than strictly necessary.
  */
 bool isPromotableToRegisterBelowThreads(
     const TensorReferenceGroup& group,
-    isl::union_map schedule,
-    size_t depth,
-    size_t nThreads) {
+    isl::multi_union_pw_aff outer,
+    isl::multi_union_pw_aff thread) {
   auto originalAccesses = group.originalAccesses();

   // Return early if more than one element needs to be stored in registers.
@@ -364,28 +339,11 @@ bool isPromotableToRegisterBelowThreads(
     return false;
   }

-  auto scheduledAccesses = originalAccesses.apply_domain(schedule);
-
-  // Scheduled accesses contain maps from schedule dimensions to tensor
-  // subscripts. Compute the relation between the schedule dimensions
-  // mapped to threads and tensor subscripts by first removing dimensions
-  // following the one mapped to thread x (last one assuming inverse mapping
-  // order), then by equating all dimensions not mapped to threads to
-  // parameters. Promotion to registers is only allowed if the resulting
-  // relation is injective, i.e. the same tensor element is never accessed by
-  // more than one thread. Note that our current check is overly conservative
-  // because different values of schedule dimension may get mapped to the same
-  // thread, in which case they could access the same tensor element.
-  for (auto sa : isl::UnionAsVector<isl::union_map>(scheduledAccesses)) {
-    sa = sa.project_out(
-        isl::dim_type::in, depth, sa.dim(isl::dim_type::in) - depth);
-    sa = fixOuterInputDimsAsParameters(sa, depth - nThreads);
-    if (!sa.is_injective()) {
-      return false;
-    }
-  }
+  auto map = isl::union_map::from(outer);
+  map = map.range_product(group.originalAccesses());
+  map = map.apply_domain(isl::union_map::from(thread));

-  return true;
+  return map.is_injective();
 }
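Not part of the patch, but for illustration: a minimal standalone sketch of the injectivity test above, on toy schedules and accesses. It assumes isl's stock C++ bindings (isl/cpp.h) and writes the outer schedule and the thread mapping directly as union_maps, whereas the patch keeps them as multi_union_pw_aff and converts them with isl::union_map::from through TC's own isl wrapper; exact method availability can differ between isl versions. The sketch only covers the second test (one thread per pair of outer iteration and element), not the single-element footprint test.

#include <iostream>
#include <isl/cpp.h>

int main() {
  isl::ctx ctx(isl_ctx_alloc());
  // Hypothetical inputs: the outer schedule keeps i, threads are indexed by j,
  // and every instance S[i, j] accesses A[j].
  isl::union_map outer(ctx, "{ S[i, j] -> O[i] }");
  isl::union_map thread(ctx, "{ S[i, j] -> T[j] }");
  isl::union_map accesses(ctx, "{ S[i, j] -> A[j] }");

  // instance -> (outer iteration, accessed element), ...
  auto map = outer.range_product(accesses);
  // ... then re-keyed by the thread that executes the instance.
  map = map.apply_domain(thread);

  // Injective: each (outer iteration, element) pair belongs to a single
  // thread, so the element can live in that thread's registers.
  std::cout << (map.is_injective() ? "promotable" : "not promotable") << "\n";
  isl_ctx_free(ctx.release());
}

With accesses { S[i, j] -> A[i] } instead, every thread j would reach the same (O[i], A[i]) pair for a given i, the composed relation would not be injective, and the check would reject the group.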

 /*
@@ -573,22 +531,16 @@ void promoteToRegistersBelowThreads(Scop& scop, size_t nRegisters) {

   auto root = scop.scheduleRoot();

-  auto fullSched = fullSchedule(root);
   {
     auto markers = findThreadSpecificMarkers(root);

     for (auto marker : markers) {
       auto partialSched = prefixSchedule(root, marker);
       // Pure affine schedule without (mapping) filters.
-      auto partialSchedMupa = prefixScheduleMupa(root, marker);
-
-      auto depth = marker->scheduleDepth(root);
-
-      // Thread mapping filters are inserted immediately above the members
-      // mapped to threads. The number of intermediate band members
-      // is therefore equal to the number of mapped thread identifiers.
       auto mapping = findThreadMappingAncestor(root, marker);
-      size_t nMappedThreads = marker->scheduleDepth(mapping);
+      auto prefixSchedMupa = prefixScheduleMupa(root, mapping);
+      auto mapSchedMupa = infixScheduleMupa(root, mapping, marker);
+      auto partialSchedMupa = prefixSchedMupa.flat_range_product(mapSchedMupa);

       auto groupMap = TensorReferenceGroup::accessedBySubtree(marker, scop);
       for (auto& tensorGroups : groupMap) {
@@ -603,7 +555,7 @@ void promoteToRegistersBelowThreads(Scop& scop, size_t nRegisters) {
           continue;
         }
         if (!isPromotableToRegisterBelowThreads(
-                *group, fullSched, depth, nMappedThreads)) {
+                *group, prefixSchedMupa, mapSchedMupa)) {
           continue;
         }
         if (!hasReuseWithin(*group, partialSchedMupa)) {
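For intuition on the reworked schedule computation (again a sketch, not TC code): prefixScheduleMupa now covers only the band members above the thread-mapping filter, infixScheduleMupa covers the members between that filter and the marker, and flat_range_product concatenates the two into the partial schedule used for the reuse test. Assuming a hypothetical three-deep nest where i sits above the mapping, j and k are mapped to threads, and the same isl C++ bindings as in the sketch above:

#include <isl/cpp.h>

int main() {
  isl::ctx ctx(isl_ctx_alloc());
  // Prefix: schedule dimensions above the thread-mapping filter.
  isl::multi_union_pw_aff prefix(ctx, "[{ S[i, j, k] -> [(i)] }]");
  // Infix: band members between the mapping filter and the marker.
  isl::multi_union_pw_aff infix(
      ctx, "[{ S[i, j, k] -> [(j)] }, { S[i, j, k] -> [(k)] }]");
  // Counterpart of partialSchedMupa in the patch: the concatenation
  // S[i, j, k] -> [i, j, k].
  auto partial = prefix.flat_range_product(infix);
  (void)partial;
  isl_ctx_free(ctx.release());
}

Splitting the schedule at the mapping node is what lets isPromotableToRegisterBelowThreads take the outer part (prefixSchedMupa) and the thread-mapped part (mapSchedMupa) as two separate arguments, instead of a full schedule plus depth and thread-count bookkeeping.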