Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 628f282

Browse files
Merge pull request #231 from facebookresearch/pr/halide_schedule
halide2isl::makeScheduleTree: use isl::schedule for constructing tree
2 parents 5d5965c + 806e3b0 commit 628f282

File tree

1 file changed

+16
-69
lines changed

1 file changed

+16
-69
lines changed

src/core/halide2isl.cc

Lines changed: 16 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -311,20 +311,14 @@ extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) {
311311
return {finder.reads, finder.writes};
312312
}
313313

314-
struct ScheduleTreeAndDomain {
315-
ScheduleTreeUPtr tree;
316-
isl::union_set domain;
317-
};
318-
319314
/*
320-
* Helper function for extracting a schedule tree from a Halide Stmt,
315+
* Helper function for extracting a schedule from a Halide Stmt,
321316
* recursively descending over the Stmt.
322317
* "s" is the current position in the recursive descent.
323318
* "set" describes the bounds on the outer loop iterators.
324319
* "outer" contains the names of the outer loop iterators
325320
* from outermost to innermost.
326-
* Return the schedule tree corresponding to the subtree at "s",
327-
* along with a separated out domain.
321+
* Return the schedule corresponding to the subtree at "s".
328322
*
329323
* "reads" and "writes" collect the accesses found along the way.
330324
* "accesses" collects the mapping from Call (for the reads) and Provide nodes
@@ -334,7 +328,7 @@ struct ScheduleTreeAndDomain {
334328
* "iterators" collects the mapping from instance set tuple identifiers
335329
* to the corresponding outer loop iterator names, from outermost to innermost.
336330
*/
337-
ScheduleTreeAndDomain makeScheduleTreeHelper(
331+
isl::schedule makeScheduleTreeHelper(
338332
const Stmt& s,
339333
isl::set set,
340334
std::vector<std::string>& outer,
@@ -343,7 +337,7 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
343337
AccessMap* accesses,
344338
StatementMap* statements,
345339
IteratorMap* iterators) {
346-
ScheduleTreeAndDomain result;
340+
isl::schedule schedule;
347341
if (auto op = s.as<For>()) {
348342
// Add one additional dimension to our set of loop variables
349343
int thisLoopIdx = set.dim(isl::dim_type::set);
@@ -397,7 +391,7 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
397391
// dimension. The spaces may be different, but they'll all have
398392
// this loop var at the same index.
399393
isl::multi_union_pw_aff mupa;
400-
body.domain.foreach_set([&](isl::set s) {
394+
body.get_domain().foreach_set([&](isl::set s) {
401395
isl::aff loopVar(
402396
isl::local_space(s.get_space()), isl::dim_type::set, thisLoopIdx);
403397
if (mupa) {
@@ -407,58 +401,20 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
407401
}
408402
});
409403

410-
if (body.tree) {
411-
result.tree = ScheduleTree::makeBand(mupa, std::move(body.tree));
412-
} else {
413-
result.tree = ScheduleTree::makeBand(mupa);
414-
}
415-
result.domain = body.domain;
404+
schedule = body.insert_partial_schedule(mupa);
416405
} else if (auto op = s.as<Halide::Internal::Block>()) {
417-
// Flatten a nested block. Halide Block statements always nest
418-
// rightwards. Flattening it is not strictly necessary, but it
419-
// keeps things uniform with the PET lowering path.
420406
std::vector<Stmt> stmts;
421407
stmts.push_back(op->first);
422408
stmts.push_back(op->rest);
423-
while (const Halide::Internal::Block* b =
424-
stmts.back().as<Halide::Internal::Block>()) {
425-
Stmt f = b->first;
426-
Stmt r = b->rest;
427-
stmts.pop_back();
428-
stmts.push_back(f);
429-
stmts.push_back(r);
430-
}
431409

432-
// Build a schedule tree for each member of the block, then set up
433-
// appropriate filters that state which statements lie in which
434-
// children.
435-
std::vector<ScheduleTreeUPtr> trees;
410+
// Build a schedule tree for both members of the block and
411+
// combine them in a sequence.
412+
std::vector<isl::schedule> schedules;
436413
for (Stmt s : stmts) {
437-
auto mem = makeScheduleTreeHelper(
438-
s, set, outer, reads, writes, accesses, statements, iterators);
439-
ScheduleTreeUPtr filter;
440-
if (mem.tree) {
441-
// No statement instances are shared between the blocks, so we
442-
// can drop the constraints on the spaces. This makes the
443-
// schedule tree slightly simpler.
444-
filter = ScheduleTree::makeFilter(
445-
mem.domain.universe(), std::move(mem.tree));
446-
} else {
447-
filter = ScheduleTree::makeFilter(mem.domain.universe());
448-
}
449-
if (result.domain) {
450-
result.domain = result.domain.unite(mem.domain);
451-
} else {
452-
result.domain = mem.domain;
453-
}
454-
trees.push_back(std::move(filter));
455-
}
456-
CHECK_GE(trees.size(), 1);
457-
458-
result.tree = ScheduleTree::makeSequence(std::move(trees[0]));
459-
for (size_t i = 1; i < trees.size(); i++) {
460-
result.tree->appendChild(std::move(trees[i]));
414+
schedules.push_back(makeScheduleTreeHelper(
415+
s, set, outer, reads, writes, accesses, statements, iterators));
461416
}
417+
schedule = schedules[0].sequence(schedules[1]);
462418

463419
} else if (auto op = s.as<Provide>()) {
464420
// Make an ID for this leaf statement. This *is* semantically
@@ -469,7 +425,7 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
469425
statements->emplace(id, op);
470426
iterators->emplace(id, outer);
471427
isl::set domain = set.set_tuple_id(id);
472-
result.domain = domain;
428+
schedule = isl::schedule::from_domain(domain);
473429

474430
isl::union_map newReads, newWrites;
475431
std::tie(newReads, newWrites) =
@@ -481,7 +437,7 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
481437
} else {
482438
LOG(FATAL) << "Unhandled Halide stmt: " << s;
483439
}
484-
return result;
440+
return schedule;
485441
};
486442

487443
ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
@@ -491,7 +447,7 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
491447

492448
// Walk the IR building a schedule tree
493449
std::vector<std::string> outer;
494-
auto treeAndDomain = makeScheduleTreeHelper(
450+
auto schedule = makeScheduleTreeHelper(
495451
s,
496452
isl::set::universe(paramSpace),
497453
outer,
@@ -501,16 +457,7 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
501457
&result.statements,
502458
&result.iterators);
503459

504-
// TODO: This fails if the stmt is just a Provide node, I'm not sure
505-
// what the schedule tree should look like in that case.
506-
CHECK(treeAndDomain.tree);
507-
508-
// Add the outermost domain node
509-
result.tree = ScheduleTree::makeDomain(
510-
treeAndDomain.domain, std::move(treeAndDomain.tree));
511-
512-
// Check we have obeyed the ISL invariants
513-
checkValidIslSchedule(result.tree.get());
460+
result.tree = fromIslSchedule(schedule);
514461

515462
return result;
516463
}

0 commit comments

Comments
 (0)