@@ -311,20 +311,14 @@ extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) {
311
311
return {finder.reads , finder.writes };
312
312
}
313
313
314
- struct ScheduleTreeAndDomain {
315
- ScheduleTreeUPtr tree;
316
- isl::union_set domain;
317
- };
318
-
319
314
/*
320
- * Helper function for extracting a schedule tree from a Halide Stmt,
315
+ * Helper function for extracting a schedule from a Halide Stmt,
321
316
* recursively descending over the Stmt.
322
317
* "s" is the current position in the recursive descent.
323
318
* "set" describes the bounds on the outer loop iterators.
324
319
* "outer" contains the names of the outer loop iterators
325
320
* from outermost to innermost.
326
- * Return the schedule tree corresponding to the subtree at "s",
327
- * along with a separated out domain.
321
+ * Return the schedule corresponding to the subtree at "s".
328
322
*
329
323
* "reads" and "writes" collect the accesses found along the way.
330
324
* "accesses" collects the mapping from Call (for the reads) and Provide nodes
@@ -334,7 +328,7 @@ struct ScheduleTreeAndDomain {
334
328
* "iterators" collects the mapping from instance set tuple identifiers
335
329
* to the corresponding outer loop iterator names, from outermost to innermost.
336
330
*/
337
- ScheduleTreeAndDomain makeScheduleTreeHelper (
331
+ isl::schedule makeScheduleTreeHelper (
338
332
const Stmt& s,
339
333
isl::set set,
340
334
std::vector<std::string>& outer,
@@ -343,7 +337,7 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
343
337
AccessMap* accesses,
344
338
StatementMap* statements,
345
339
IteratorMap* iterators) {
346
- ScheduleTreeAndDomain result ;
340
+ isl::schedule schedule ;
347
341
if (auto op = s.as <For>()) {
348
342
// Add one additional dimension to our set of loop variables
349
343
int thisLoopIdx = set.dim (isl::dim_type::set);
@@ -397,7 +391,7 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
397
391
// dimension. The spaces may be different, but they'll all have
398
392
// this loop var at the same index.
399
393
isl::multi_union_pw_aff mupa;
400
- body.domain .foreach_set ([&](isl::set s) {
394
+ body.get_domain () .foreach_set ([&](isl::set s) {
401
395
isl::aff loopVar (
402
396
isl::local_space (s.get_space ()), isl::dim_type::set, thisLoopIdx);
403
397
if (mupa) {
@@ -407,12 +401,7 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
407
401
}
408
402
});
409
403
410
- if (body.tree ) {
411
- result.tree = ScheduleTree::makeBand (mupa, std::move (body.tree ));
412
- } else {
413
- result.tree = ScheduleTree::makeBand (mupa);
414
- }
415
- result.domain = body.domain ;
404
+ schedule = body.insert_partial_schedule (mupa);
416
405
} else if (auto op = s.as <Halide::Internal::Block>()) {
417
406
// Flatten a nested block. Halide Block statements always nest
418
407
// rightwards. Flattening it is not strictly necessary, but it
@@ -429,35 +418,16 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
429
418
stmts.push_back (r);
430
419
}
431
420
432
- // Build a schedule tree for each member of the block, then set up
433
- // appropriate filters that state which statements lie in which
434
- // children.
435
- std::vector<ScheduleTreeUPtr> trees;
421
+ // Build a schedule tree for each member of the block and
422
+ // collect them in a sequence.
436
423
for (Stmt s : stmts) {
437
424
auto mem = makeScheduleTreeHelper (
438
425
s, set, outer, reads, writes, accesses, statements, iterators);
439
- ScheduleTreeUPtr filter;
440
- if (mem.tree ) {
441
- // No statement instances are shared between the blocks, so we
442
- // can drop the constraints on the spaces. This makes the
443
- // schedule tree slightly simpler.
444
- filter = ScheduleTree::makeFilter (
445
- mem.domain .universe (), std::move (mem.tree ));
446
- } else {
447
- filter = ScheduleTree::makeFilter (mem.domain .universe ());
448
- }
449
- if (result.domain ) {
450
- result.domain = result.domain .unite (mem.domain );
426
+ if (schedule) {
427
+ schedule = schedule.sequence (mem);
451
428
} else {
452
- result. domain = mem. domain ;
429
+ schedule = mem;
453
430
}
454
- trees.push_back (std::move (filter));
455
- }
456
- CHECK_GE (trees.size (), 1 );
457
-
458
- result.tree = ScheduleTree::makeSequence (std::move (trees[0 ]));
459
- for (size_t i = 1 ; i < trees.size (); i++) {
460
- result.tree ->appendChild (std::move (trees[i]));
461
431
}
462
432
463
433
} else if (auto op = s.as <Provide>()) {
@@ -469,7 +439,7 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
469
439
statements->emplace (id, op);
470
440
iterators->emplace (id, outer);
471
441
isl::set domain = set.set_tuple_id (id);
472
- result. domain = domain;
442
+ schedule = isl::schedule::from_domain ( domain) ;
473
443
474
444
isl::union_map newReads, newWrites;
475
445
std::tie (newReads, newWrites) =
@@ -481,7 +451,7 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
481
451
} else {
482
452
LOG (FATAL) << " Unhandled Halide stmt: " << s;
483
453
}
484
- return result ;
454
+ return schedule ;
485
455
};
486
456
487
457
ScheduleTreeAndAccesses makeScheduleTree (isl::space paramSpace, const Stmt& s) {
@@ -491,7 +461,7 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
491
461
492
462
// Walk the IR building a schedule tree
493
463
std::vector<std::string> outer;
494
- auto treeAndDomain = makeScheduleTreeHelper (
464
+ auto schedule = makeScheduleTreeHelper (
495
465
s,
496
466
isl::set::universe (paramSpace),
497
467
outer,
@@ -501,16 +471,7 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
501
471
&result.statements ,
502
472
&result.iterators );
503
473
504
- // TODO: This fails if the stmt is just a Provide node, I'm not sure
505
- // what the schedule tree should look like in that case.
506
- CHECK (treeAndDomain.tree );
507
-
508
- // Add the outermost domain node
509
- result.tree = ScheduleTree::makeDomain (
510
- treeAndDomain.domain , std::move (treeAndDomain.tree ));
511
-
512
- // Check we have obeyed the ISL invariants
513
- checkValidIslSchedule (result.tree .get ());
474
+ result.tree = fromIslSchedule (schedule);
514
475
515
476
return result;
516
477
}
0 commit comments