Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit bc63ab4

Browse files
author
Sven Verdoolaege
committed
store data extracted from statement bodies in Body structure
This data is collected by makeScheduleTree and then copied into and between Scops. It is easier to do this collection on a single object rather than on multiple objects. For now, the store information consists of two objects, the read and write accesses, but this will be extended with a polyhedral representation of reductions. The fields of Body also support some common operations, supporting the fact that they belong together.
1 parent 9a8fbbc commit bc63ab4

13 files changed

+121
-44
lines changed

tc/core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ add_library(
2121
halide2isl.cc
2222
halide_utils.cc
2323

24+
polyhedral/body.cc
2425
polyhedral/codegen.cc
2526
polyhedral/memory_promotion.cc
2627
polyhedral/reduction_matcher.cc

tc/core/halide2isl.cc

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include "tc/core/check.h"
2222
#include "tc/core/constants.h"
23+
#include "tc/core/polyhedral/body.h"
2324
#include "tc/core/polyhedral/schedule_isl_conversion.h"
2425
#include "tc/core/polyhedral/schedule_transforms.h"
2526
#include "tc/core/polyhedral/schedule_tree.h"
@@ -360,7 +361,7 @@ onDomains(isl::aff f, isl::union_set domain, const IterationDomainMap& map) {
360361
* from outermost to innermost.
361362
* Return the schedule corresponding to the subtree at "s".
362363
*
363-
* "reads" and "writes" collect the accesses found along the way.
364+
* "body" collects the accesses found along the way.
364365
* "accesses" collects the mapping from Call (for the reads) and Provide nodes
365366
* (for the writes) to the corresponding tag in the access relations.
366367
* "statements" collects the mapping from instance set tuple identifiers
@@ -372,8 +373,7 @@ isl::schedule makeScheduleTreeHelper(
372373
const Stmt& s,
373374
isl::set set,
374375
isl::id_list outer,
375-
isl::union_map* reads,
376-
isl::union_map* writes,
376+
Body* body,
377377
AccessMap* accesses,
378378
StatementMap* statements,
379379
IterationDomainMap* domains) {
@@ -406,19 +406,19 @@ isl::schedule makeScheduleTreeHelper(
406406

407407
// Recursively descend.
408408
auto outerNext = outer.add(isl::id(set.get_ctx(), op->name));
409-
auto body = makeScheduleTreeHelper(
410-
op->body, set, outerNext, reads, writes, accesses, statements, domains);
409+
auto bodySchedule = makeScheduleTreeHelper(
410+
op->body, set, outerNext, body, accesses, statements, domains);
411411

412412
// Create an affine function that defines an ordering for all
413413
// the statements in the body of this loop over the values of
414414
// this loop. Start from a parametric expression equal
415415
// to the current loop iterator and then convert it to
416416
// a function on the statements in the domain of the body schedule.
417417
auto aff = isl::aff::param_on_domain_space(space, id);
418-
auto domain = body.get_domain();
418+
auto domain = bodySchedule.get_domain();
419419
auto mupa = isl::multi_union_pw_aff(onDomains(aff, domain, *domains));
420420

421-
schedule = body.insert_partial_schedule(mupa);
421+
schedule = bodySchedule.insert_partial_schedule(mupa);
422422
} else if (auto op = s.as<Halide::Internal::Block>()) {
423423
std::vector<Stmt> stmts;
424424
stmts.push_back(op->first);
@@ -429,7 +429,7 @@ isl::schedule makeScheduleTreeHelper(
429429
std::vector<isl::schedule> schedules;
430430
for (Stmt stmt : stmts) {
431431
schedules.push_back(makeScheduleTreeHelper(
432-
stmt, set, outer, reads, writes, accesses, statements, domains));
432+
stmt, set, outer, body, accesses, statements, domains));
433433
}
434434
schedule = schedules[0].sequence(schedules[1]);
435435

@@ -453,8 +453,8 @@ isl::schedule makeScheduleTreeHelper(
453453
std::tie(newReads, newWrites) =
454454
extractAccesses(iterationDomain, op, accesses);
455455

456-
*reads = reads->unite(newReads);
457-
*writes = writes->unite(newWrites);
456+
body->reads = body->reads.unite(newReads);
457+
body->writes = body->writes.unite(newWrites);
458458

459459
} else {
460460
LOG(FATAL) << "Unhandled Halide stmt: " << s;
@@ -465,20 +465,20 @@ isl::schedule makeScheduleTreeHelper(
465465
ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
466466
ScheduleTreeAndAccesses result;
467467

468-
result.writes = result.reads = isl::union_map::empty(paramSpace);
468+
Body body(paramSpace);
469469

470470
// Walk the IR building a schedule tree
471471
isl::id_list outer(paramSpace.get_ctx(), 0);
472472
auto schedule = makeScheduleTreeHelper(
473473
s,
474474
isl::set::universe(paramSpace),
475475
outer,
476-
&result.reads,
477-
&result.writes,
476+
&body,
478477
&result.accesses,
479478
&result.statements,
480479
&result.domains);
481480

481+
result.body = body;
482482
result.tree = fromIslSchedule(schedule);
483483

484484
return result;

tc/core/halide2isl.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#include <Halide.h>
2424

25+
#include "tc/core/polyhedral/body.h"
2526
#include "tc/core/polyhedral/schedule_tree.h"
2627
#include "tc/core/tc2halide.h"
2728
#include "tc/external/isl.h"
@@ -82,10 +83,8 @@ struct ScheduleTreeAndAccesses {
8283
/// for each leaf node is captured below.
8384
tc::polyhedral::ScheduleTreeUPtr tree;
8485

85-
/// Union maps describing the reads and writes done. Uses the ids in
86-
/// the schedule tree to denote the containing Stmt, and tags each
87-
/// access with a unique reference id of the form __tc_ref_N.
88-
isl::union_map reads, writes;
86+
/// Information extracted from the bodies of the statements.
87+
tc::polyhedral::Body body;
8988

9089
/// The correspondence between from Call and Provide nodes and the
9190
/// reference ids in the reads and writes maps.

tc/core/polyhedral/body.cc

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/**
2+
* Copyright (c) 2018, Facebook, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "tc/core/polyhedral/body.h"
18+
19+
#include <iostream>
20+
21+
#include "tc/external/isl.h"
22+
23+
namespace tc {
24+
namespace polyhedral {
25+
26+
std::ostream& operator<<(std::ostream& os, const Body& body) {
27+
os << "reads: " << body.reads << "\n";
28+
os << "writes: " << body.writes << "\n";
29+
30+
return os;
31+
}
32+
33+
} // namespace polyhedral
34+
} // namespace tc

tc/core/polyhedral/body.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/**
2+
* Copyright (c) 2018, Facebook, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include <iostream>
19+
20+
#include "tc/external/isl.h"
21+
22+
namespace tc {
23+
namespace polyhedral {
24+
25+
// Information about the bodies of the polyhedral statements.
26+
struct Body {
27+
Body() = default;
28+
Body(isl::space paramSpace) {
29+
writes = reads = isl::union_map::empty(paramSpace);
30+
}
31+
32+
// Specialize to the given context.
33+
void specialize(isl::set context) {
34+
reads = reads.intersect_params(context);
35+
writes = writes.intersect_params(context);
36+
}
37+
38+
// Union maps describing the reads and writes done. Uses the ids in
39+
// the schedule tree to denote the containing Stmt, and tags each
40+
// access with a unique reference id of the form __tc_ref_N.
41+
isl::union_map reads, writes;
42+
};
43+
44+
std::ostream& operator<<(std::ostream& os, const Body& body);
45+
46+
} // namespace polyhedral
47+
} // namespace tc

tc/core/polyhedral/cuda/codegen.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "tc/core/check.h"
2424
#include "tc/core/cuda/cuda_libraries.h"
2525
#include "tc/core/flags.h"
26+
#include "tc/core/polyhedral/body.h"
2627
#include "tc/core/polyhedral/codegen.h"
2728
#include "tc/core/polyhedral/cuda/codegen.h"
2829
#include "tc/core/polyhedral/cuda/mapping_types.h"
@@ -769,8 +770,8 @@ std::unordered_set<isl::id, isl::IslIdIslHash> gatherReadOnlySet(
769770

770771
const auto& scop = mscop.scop();
771772

772-
auto read = scop.reads.universe().range();
773-
auto written = scop.writes.universe().range();
773+
auto read = scop.body.reads.universe().range();
774+
auto written = scop.body.writes.universe().range();
774775
auto readOnly = read.subtract(written);
775776
for (auto s : readOnly.get_set_list()) {
776777
readOnlySet.emplace(s.get_tuple_id());

tc/core/polyhedral/cuda/memory_promotion_heuristic.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,7 @@ void promoteToSharedGreedy(
523523
auto partialSched = partialSchedule(root, bandNode);
524524

525525
auto groupMap = TensorReferenceGroup::accessedWithin(
526-
partialSched.intersect_domain(activePoints), scop.reads, scop.writes);
526+
partialSched.intersect_domain(activePoints), scop.body);
527527
// Pure affine schedule without (mapping) filters.
528528
auto partialSchedMupa = partialScheduleMupa(root, bandNode);
529529

@@ -650,7 +650,7 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
650650
collectMappingsTo<mapping::ThreadId>(scop).intersect(blockMapping);
651651
auto schedule = partialSchedule(scop.scheduleRoot(), scope);
652652
auto groupMap = TensorReferenceGroup::accessedWithin(
653-
schedule.intersect_domain(mapping), scop.reads, scop.writes);
653+
schedule.intersect_domain(mapping), scop.body);
654654

655655
auto threadSchedule = mscop.threadMappingSchedule(mscop.schedule());
656656
auto blockSchedule = mscop.blockMappingSchedule(mscop.schedule());

tc/core/polyhedral/memory_promotion.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <unordered_map>
2222

2323
#include "tc/core/check.h"
24+
#include "tc/core/polyhedral/body.h"
2425
#include "tc/core/polyhedral/exceptions.h"
2526
#include "tc/core/polyhedral/schedule_tree.h"
2627
#include "tc/core/polyhedral/scop.h"
@@ -341,15 +342,14 @@ void addSingletonReferenceGroups(
341342
// references.
342343
TensorGroups TensorReferenceGroup::accessedWithin(
343344
isl::union_map outerSchedule,
344-
isl::union_map reads,
345-
isl::union_map writes) {
345+
const Body& body) {
346346
TensorGroups tensorGroups;
347347
auto domain = outerSchedule.domain();
348348

349349
addSingletonReferenceGroups(
350-
tensorGroups, writes, domain, outerSchedule, AccessType::Write);
350+
tensorGroups, body.writes, domain, outerSchedule, AccessType::Write);
351351
addSingletonReferenceGroups(
352-
tensorGroups, reads, domain, outerSchedule, AccessType::Read);
352+
tensorGroups, body.reads, domain, outerSchedule, AccessType::Read);
353353

354354
// For each tensor, join groups whose footprints overlap and at least one
355355
// access is a write. Do not join between tensors because no aliasing.

tc/core/polyhedral/memory_promotion.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,7 @@ class TensorReferenceGroup {
113113
public:
114114
static TensorGroups accessedWithin(
115115
isl::union_map outerSchedule,
116-
isl::union_map reads,
117-
isl::union_map writes);
116+
const Body& body);
118117

119118
bool isReadOnly() const;
120119

tc/core/polyhedral/scop.cc

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
#include "tc/core/check.h"
2727
#include "tc/core/halide2isl.h"
28+
#include "tc/core/polyhedral/body.h"
2829
#include "tc/core/polyhedral/functional.h"
2930
#include "tc/core/polyhedral/memory_promotion.h"
3031
#include "tc/core/polyhedral/schedule_isl_conversion.h"
@@ -60,8 +61,7 @@ ScopUPtr Scop::makeScop(
6061

6162
auto tree = halide2isl::makeScheduleTree(paramSpace, components.stmt);
6263
scop->scheduleTreeUPtr = std::move(tree.tree);
63-
scop->reads = tree.reads;
64-
scop->writes = tree.writes;
64+
scop->body = tree.body;
6565
scop->halide.statements = std::move(tree.statements);
6666
scop->halide.accesses = std::move(tree.accesses);
6767
scop->halide.reductions = halide2isl::findReductions(components.stmt);
@@ -97,8 +97,7 @@ const isl::union_set Scop::domain() const {
9797

9898
std::ostream& operator<<(std::ostream& os, const Scop& s) {
9999
os << "domain: " << s.domain() << "\n";
100-
os << "reads: " << s.reads << "\n";
101-
os << "writes: " << s.writes << "\n";
100+
os << s.body;
102101
os << "schedule: " << *s.scheduleRoot() << "\n";
103102
os << "idx: { ";
104103
for (auto i : s.halide.idx) {
@@ -256,7 +255,7 @@ void Scop::promoteEverythingAt(std::vector<size_t> pos) {
256255
checkFiltersDisjointStatements(scheduleRoot());
257256
auto schedule = partialSchedule(root, tree);
258257

259-
auto groupMap = TensorReferenceGroup::accessedWithin(schedule, reads, writes);
258+
auto groupMap = TensorReferenceGroup::accessedWithin(schedule, body);
260259
for (auto& p : groupMap) {
261260
for (auto& gr : p.second) {
262261
promoteGroup(
@@ -353,8 +352,8 @@ isl::schedule_constraints makeScheduleConstraints(
353352

354353
void Scop::computeAllDependences() {
355354
auto schedule = toIslSchedule(scheduleRoot());
356-
auto allReads = reads.domain_factor_domain();
357-
auto allWrites = writes.domain_factor_domain();
355+
auto allReads = body.reads.domain_factor_domain();
356+
auto allWrites = body.writes.domain_factor_domain();
358357
// RAW
359358
auto flowDeps = computeDependences(allWrites, allReads, schedule);
360359
// WAR and WAW

0 commit comments

Comments
 (0)