@@ -311,20 +311,14 @@ extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) {
311
311
return {finder.reads , finder.writes };
312
312
}
313
313
314
- struct ScheduleTreeAndDomain {
315
- ScheduleTreeUPtr tree;
316
- isl::union_set domain;
317
- };
318
-
319
314
/*
320
- * Helper function for extracting a schedule tree from a Halide Stmt,
315
+ * Helper function for extracting a schedule from a Halide Stmt,
321
316
* recursively descending over the Stmt.
322
317
* "s" is the current position in the recursive descent.
323
318
* "set" describes the bounds on the outer loop iterators.
324
319
* "outer" contains the names of the outer loop iterators
325
320
* from outermost to innermost.
326
- * Return the schedule tree corresponding to the subtree at "s",
327
- * along with a separated out domain.
321
+ * Return the schedule corresponding to the subtree at "s".
328
322
*
329
323
* "reads" and "writes" collect the accesses found along the way.
330
324
* "accesses" collects the mapping from Call (for the reads) and Provide nodes
@@ -334,7 +328,7 @@ struct ScheduleTreeAndDomain {
334
328
* "iterators" collects the mapping from instance set tuple identifiers
335
329
* to the corresponding outer loop iterator names, from outermost to innermost.
336
330
*/
337
- ScheduleTreeAndDomain makeScheduleTreeHelper (
331
+ isl::schedule makeScheduleTreeHelper (
338
332
const Stmt& s,
339
333
isl::set set,
340
334
std::vector<std::string>& outer,
@@ -343,7 +337,7 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
343
337
AccessMap* accesses,
344
338
StatementMap* statements,
345
339
IteratorMap* iterators) {
346
- ScheduleTreeAndDomain result ;
340
+ isl::schedule schedule ;
347
341
if (auto op = s.as <For>()) {
348
342
// Add one additional dimension to our set of loop variables
349
343
int thisLoopIdx = set.dim (isl::dim_type::set);
@@ -397,7 +391,7 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
397
391
// dimension. The spaces may be different, but they'll all have
398
392
// this loop var at the same index.
399
393
isl::multi_union_pw_aff mupa;
400
- body.domain .foreach_set ([&](isl::set s) {
394
+ body.get_domain () .foreach_set ([&](isl::set s) {
401
395
isl::aff loopVar (
402
396
isl::local_space (s.get_space ()), isl::dim_type::set, thisLoopIdx);
403
397
if (mupa) {
@@ -407,58 +401,20 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
407
401
}
408
402
});
409
403
410
- if (body.tree ) {
411
- result.tree = ScheduleTree::makeBand (mupa, std::move (body.tree ));
412
- } else {
413
- result.tree = ScheduleTree::makeBand (mupa);
414
- }
415
- result.domain = body.domain ;
404
+ schedule = body.insert_partial_schedule (mupa);
416
405
} else if (auto op = s.as <Halide::Internal::Block>()) {
417
- // Flatten a nested block. Halide Block statements always nest
418
- // rightwards. Flattening it is not strictly necessary, but it
419
- // keeps things uniform with the PET lowering path.
420
406
std::vector<Stmt> stmts;
421
407
stmts.push_back (op->first );
422
408
stmts.push_back (op->rest );
423
- while (const Halide::Internal::Block* b =
424
- stmts.back ().as <Halide::Internal::Block>()) {
425
- Stmt f = b->first ;
426
- Stmt r = b->rest ;
427
- stmts.pop_back ();
428
- stmts.push_back (f);
429
- stmts.push_back (r);
430
- }
431
409
432
- // Build a schedule tree for each member of the block, then set up
433
- // appropriate filters that state which statements lie in which
434
- // children.
435
- std::vector<ScheduleTreeUPtr> trees;
410
+ // Build a schedule tree for both members of the block and
411
+ // combine them in a sequence.
412
+ std::vector<isl::schedule> schedules;
436
413
for (Stmt s : stmts) {
437
- auto mem = makeScheduleTreeHelper (
438
- s, set, outer, reads, writes, accesses, statements, iterators);
439
- ScheduleTreeUPtr filter;
440
- if (mem.tree ) {
441
- // No statement instances are shared between the blocks, so we
442
- // can drop the constraints on the spaces. This makes the
443
- // schedule tree slightly simpler.
444
- filter = ScheduleTree::makeFilter (
445
- mem.domain .universe (), std::move (mem.tree ));
446
- } else {
447
- filter = ScheduleTree::makeFilter (mem.domain .universe ());
448
- }
449
- if (result.domain ) {
450
- result.domain = result.domain .unite (mem.domain );
451
- } else {
452
- result.domain = mem.domain ;
453
- }
454
- trees.push_back (std::move (filter));
455
- }
456
- CHECK_GE (trees.size (), 1 );
457
-
458
- result.tree = ScheduleTree::makeSequence (std::move (trees[0 ]));
459
- for (size_t i = 1 ; i < trees.size (); i++) {
460
- result.tree ->appendChild (std::move (trees[i]));
414
+ schedules.push_back (makeScheduleTreeHelper (
415
+ s, set, outer, reads, writes, accesses, statements, iterators));
461
416
}
417
+ schedule = schedules[0 ].sequence (schedules[1 ]);
462
418
463
419
} else if (auto op = s.as <Provide>()) {
464
420
// Make an ID for this leaf statement. This *is* semantically
@@ -469,7 +425,7 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
469
425
statements->emplace (id, op);
470
426
iterators->emplace (id, outer);
471
427
isl::set domain = set.set_tuple_id (id);
472
- result. domain = domain;
428
+ schedule = isl::schedule::from_domain ( domain) ;
473
429
474
430
isl::union_map newReads, newWrites;
475
431
std::tie (newReads, newWrites) =
@@ -481,7 +437,7 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
481
437
} else {
482
438
LOG (FATAL) << " Unhandled Halide stmt: " << s;
483
439
}
484
- return result ;
440
+ return schedule ;
485
441
};
486
442
487
443
ScheduleTreeAndAccesses makeScheduleTree (isl::space paramSpace, const Stmt& s) {
@@ -491,7 +447,7 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
491
447
492
448
// Walk the IR building a schedule tree
493
449
std::vector<std::string> outer;
494
- auto treeAndDomain = makeScheduleTreeHelper (
450
+ auto schedule = makeScheduleTreeHelper (
495
451
s,
496
452
isl::set::universe (paramSpace),
497
453
outer,
@@ -501,16 +457,7 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
501
457
&result.statements ,
502
458
&result.iterators );
503
459
504
- // TODO: This fails if the stmt is just a Provide node, I'm not sure
505
- // what the schedule tree should look like in that case.
506
- CHECK (treeAndDomain.tree );
507
-
508
- // Add the outermost domain node
509
- result.tree = ScheduleTree::makeDomain (
510
- treeAndDomain.domain , std::move (treeAndDomain.tree ));
511
-
512
- // Check we have obeyed the ISL invariants
513
- checkValidIslSchedule (result.tree .get ());
460
+ result.tree = fromIslSchedule (schedule);
514
461
515
462
return result;
516
463
}
0 commit comments