@@ -333,6 +333,22 @@ std::pair<isl::union_map, isl::union_map> extractAccesses(
333
333
return {finder.reads , finder.writes };
334
334
}
335
335
336
+ /*
337
+ * Take a parametric expression "f" and convert it into an expression
338
+ * on the iteration domains in "domain" by reinterpreting the parameters
339
+ * as set dimensions according to the corresponding tuples in "map".
340
+ */
341
+ isl::union_pw_aff
342
+ onDomains (isl::aff f, isl::union_set domain, const IterationDomainMap& map) {
343
+ auto upa = isl::union_pw_aff::empty (domain.get_space ());
344
+ for (auto set : domain.get_set_list ()) {
345
+ auto tuple = map.at (set.get_tuple_id ()).tuple ;
346
+ auto onSet = isl::union_pw_aff (f.unbind_params_insert_domain (tuple));
347
+ upa = upa.union_add (onSet);
348
+ }
349
+ return upa;
350
+ }
351
+
336
352
} // namespace
337
353
338
354
/*
@@ -395,20 +411,12 @@ isl::schedule makeScheduleTreeHelper(
395
411
396
412
// Create an affine function that defines an ordering for all
397
413
// the statements in the body of this loop over the values of
398
- // this loop. For each statement in the children we want the
399
- // function that maps everything in its space to this
400
- // dimension. The spaces may be different, but they'll all have
401
- // this loop var at the same index.
402
- isl::multi_union_pw_aff mupa;
403
- body.get_domain ().foreach_set ([&](isl::set s) {
404
- isl::aff newLoopVar (
405
- isl::local_space (s.get_space ()), isl::dim_type::set, outer.n ());
406
- if (mupa) {
407
- mupa = mupa.union_add (isl::union_pw_aff (isl::pw_aff (newLoopVar)));
408
- } else {
409
- mupa = isl::union_pw_aff (isl::pw_aff (newLoopVar));
410
- }
411
- });
414
+ // this loop. Start from a parametric expression equal
415
+ // to the current loop iterator and then convert it to
416
+ // a function on the statements in the domain of the body schedule.
417
+ auto aff = isl::aff::param_on_domain_space (space, id);
418
+ auto domain = body.get_domain ();
419
+ auto mupa = isl::multi_union_pw_aff (onDomains (aff, domain, *domains));
412
420
413
421
schedule = body.insert_partial_schedule (mupa);
414
422
} else if (auto op = s.as <Halide::Internal::Block>()) {
0 commit comments