Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 1b30d6d

Browse files
committed
ScheduleTreeElem*: inherit from ScheduleTree instead of ScheduleTreeElemBase
The existing implementation of TC schedule trees, inherited from the prehistory, chose to have an abstract tree node type (ScheduleTree) with a unique pointer to a specific node element (ScheduleTreeElem*) as payload. It is unclear now why this structure was chosen, possible reasons include: the duality between isl_schedule_tree and isl_schedule_node in isl internals; gradual movement from raw pointers, to shared, to unique pointers; an attempt to separate tree memory management (unique pointers to children in ScheduleTree) from data (fields in ScheduleTreeElem*). However, the resulting structure ended up being unnecessarily complex. Conceptually, a specific node type _is-a_ tree node, so it makes sense to have specific nodes inherit from a generic one. isl C++ API already provides this structure and it proved to be more convenient. In particular, it enables functions to specify that they expect a certain type of node statically as a part of their signature rather than dynamically through CHECK macros. Make ScheduleTreeElem* classes inherit from ScheduleTree instead of ScheduleTreeElemBase and remove the "elem_" field from ScheduleTree (the tree class itself can now be casted to a subclass). This change resulted in changes to copy and static constructors of the ScheduleTree* classes. In particular, Sequence, Set and ThreadSpecificMarker elements now use the global isl context in construction since the ScheduleTree constructor requires an isl context. This will be cleaned up in an upcoming commit. Note that this commit does not rename elem-related functions for the sake of diff minimization. These changes will be performed in subsequent commits.
1 parent 2f60358 commit 1b30d6d

File tree

6 files changed

+110
-134
lines changed

6 files changed

+110
-134
lines changed

tc/core/polyhedral/schedule_isl_conversion.cc

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -248,8 +248,7 @@ std::unique_ptr<ScheduleTreeElemBand> fromIslScheduleNodeBand(
248248
return res;
249249
}
250250

251-
std::unique_ptr<ScheduleTreeElemBase> elemFromIslScheduleNode(
252-
isl::schedule_node node) {
251+
std::unique_ptr<ScheduleTree> elemFromIslScheduleNode(isl::schedule_node node) {
253252
if (auto band = node.as<isl::schedule_node_band>()) {
254253
return fromIslScheduleNodeBand(band);
255254
} else if (auto context = node.as<isl::schedule_node_context>()) {
@@ -278,15 +277,15 @@ std::unique_ptr<ScheduleTreeElemBase> elemFromIslScheduleNode(
278277
LOG(FATAL) << "mark nodes not supported";
279278
return nullptr;
280279
} else if (node.isa<isl::schedule_node_leaf>()) {
281-
LOG(FATAL) << "ScheduleTreeElemBase::make called on explicit leaf";
280+
LOG(FATAL) << "ScheduleTree::make called on explicit leaf";
282281
return nullptr;
283282
} else if (node.isa<isl::schedule_node_sequence>()) {
284283
return std::unique_ptr<ScheduleTreeElemSequence>(
285284
new ScheduleTreeElemSequence());
286285
} else if (node.isa<isl::schedule_node_set>()) {
287286
return std::unique_ptr<ScheduleTreeElemSet>(new ScheduleTreeElemSet());
288287
}
289-
LOG(FATAL) << "NYI: ScheduleTreeElemBase from type: "
288+
LOG(FATAL) << "NYI: ScheduleTree from type: "
290289
<< isl_schedule_node_get_type(node.get());
291290
return nullptr;
292291
}
@@ -299,9 +298,7 @@ std::unique_ptr<ScheduleTreeElemBase> elemFromIslScheduleNode(
299298
* if this single child node is a leaf.
300299
*/
301300
std::unique_ptr<ScheduleTree> fromIslScheduleNode(isl::schedule_node node) {
302-
unique_ptr<ScheduleTree> res(new ScheduleTree(node.get_ctx()));
303-
res->elem_ = elemFromIslScheduleNode(node);
304-
res->type_ = res->elem_->type();
301+
auto res = elemFromIslScheduleNode(node);
305302
auto n = node.n_children();
306303
if (n == 1 && node.child(0).isa<isl::schedule_node_leaf>()) {
307304
return res;

tc/core/polyhedral/schedule_print.cc

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,6 @@ std::ostream& operator<<(std::ostream& os, detail::ScheduleTreeType nt) {
137137
return os;
138138
}
139139

140-
std::ostream& operator<<(std::ostream& os, const ScheduleTreeElemBase& st) {
141-
st.write(os);
142-
return os;
143-
}
144-
145140
std::ostream& ScheduleTreeElemBand::write(std::ostream& os) const {
146141
WS w;
147142
os << w.tab() << "band(n(" << coincident_.size() << ") permutable(";
@@ -253,8 +248,8 @@ std::ostream& operator<<(
253248
}
254249

255250
std::ostream& operator<<(std::ostream& os, const ScheduleTree& st) {
256-
TC_CHECK(st.elem_.get());
257-
os << *st.elem_ << "\n";
251+
st.write(os);
252+
os << "\n";
258253
os << st.children_;
259254

260255
return os;

tc/core/polyhedral/schedule_tree.cc

Lines changed: 20 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,10 @@ vector<ScheduleTree*> ancestorsInSubTree(
123123
return res;
124124
}
125125

126-
static std::unique_ptr<ScheduleTreeElemBase> makeElem(const ScheduleTree& st) {
127-
#define ELEM_MAKE_CASE(CLASS) \
128-
else if (st.type_ == CLASS::NodeType) { \
129-
return std::unique_ptr<CLASS>( \
130-
new CLASS(*static_cast<CLASS*>(st.elem_.get()))); \
126+
static std::unique_ptr<ScheduleTree> makeElem(const ScheduleTree& st) {
127+
#define ELEM_MAKE_CASE(CLASS) \
128+
else if (st.type_ == CLASS::NodeType) { \
129+
return std::unique_ptr<CLASS>(new CLASS(static_cast<const CLASS&>(st))); \
131130
}
132131

133132
if (st.type_ == detail::ScheduleTreeType::None) {
@@ -145,8 +144,7 @@ static std::unique_ptr<ScheduleTreeElemBase> makeElem(const ScheduleTree& st) {
145144

146145
#undef ELEM_MAKE_CASE
147146

148-
LOG(FATAL) << "NYI: ScheduleTreeElemBase from type: "
149-
<< static_cast<int>(st.type_);
147+
LOG(FATAL) << "NYI: ScheduleTree from type: " << static_cast<int>(st.type_);
150148
return nullptr;
151149
}
152150
} // namespace
@@ -159,17 +157,15 @@ ScheduleTree::ScheduleTree(isl::ctx ctx) : ctx_(ctx) {}
159157
ScheduleTree::~ScheduleTree() {}
160158

161159
ScheduleTree::ScheduleTree(const ScheduleTree& st)
162-
: ctx_(st.ctx_), children_(), type_(st.type_), elem_(nullptr) {
160+
: ctx_(st.ctx_), children_(), type_(st.type_) {
163161
children_.reserve(st.children_.size());
164162
for (const auto& child : st.children()) {
165163
children_.push_back(ScheduleTree::makeScheduleTree(*child));
166164
}
167165
}
168166

169167
ScheduleTreeUPtr ScheduleTree::makeScheduleTree(const ScheduleTree& tree) {
170-
auto res = ScheduleTreeUPtr(new ScheduleTree(tree));
171-
res->elem_ = makeElem(tree);
172-
return res;
168+
return makeElem(tree);
173169
}
174170

175171
ScheduleTree* ScheduleTree::child(const vector<size_t>& positions) {
@@ -232,9 +228,7 @@ size_t ScheduleTree::scheduleDepth(const ScheduleTree* relativeRoot) const {
232228
std::unique_ptr<ScheduleTree> ScheduleTree::makeBand(
233229
isl::multi_union_pw_aff mupa,
234230
std::vector<ScheduleTreeUPtr>&& children) {
235-
isl::ctx ctx = mupa.get_ctx();
236-
std::unique_ptr<ScheduleTree> res(new ScheduleTree(ctx));
237-
res->elem_ = ScheduleTreeElemBand::fromMultiUnionPwAff(mupa);
231+
auto res = ScheduleTreeElemBand::fromMultiUnionPwAff(mupa);
238232
res->type_ = detail::ScheduleTreeType::Band;
239233
res->appendChildren(std::move(children));
240234
return res;
@@ -252,10 +246,7 @@ ScheduleTreeUPtr ScheduleTree::makeEmptyBand(const ScheduleTree* root) {
252246
std::unique_ptr<ScheduleTree> ScheduleTree::makeDomain(
253247
isl::union_set domain,
254248
std::vector<ScheduleTreeUPtr>&& children) {
255-
isl::ctx ctx(domain.get_ctx());
256-
std::unique_ptr<ScheduleTree> res(new ScheduleTree(ctx));
257-
res->elem_ = std::unique_ptr<ScheduleTreeElemDomain>(
258-
new ScheduleTreeElemDomain(domain));
249+
auto res = std::unique_ptr<ScheduleTree>(new ScheduleTreeElemDomain(domain));
259250
res->type_ = detail::ScheduleTreeType::Domain;
260251
res->appendChildren(std::move(children));
261252
return res;
@@ -264,10 +255,8 @@ std::unique_ptr<ScheduleTree> ScheduleTree::makeDomain(
264255
std::unique_ptr<ScheduleTree> ScheduleTree::makeContext(
265256
isl::set context,
266257
std::vector<ScheduleTreeUPtr>&& children) {
267-
isl::ctx ctx(context.get_ctx());
268-
std::unique_ptr<ScheduleTree> res(new ScheduleTree(ctx));
269-
res->elem_ = std::unique_ptr<ScheduleTreeElemContext>(
270-
new ScheduleTreeElemContext(context));
258+
auto res =
259+
std::unique_ptr<ScheduleTree>(new ScheduleTreeElemContext(context));
271260
res->type_ = detail::ScheduleTreeType::Context;
272261
res->appendChildren(std::move(children));
273262
return res;
@@ -276,10 +265,7 @@ std::unique_ptr<ScheduleTree> ScheduleTree::makeContext(
276265
std::unique_ptr<ScheduleTree> ScheduleTree::makeFilter(
277266
isl::union_set filter,
278267
std::vector<ScheduleTreeUPtr>&& children) {
279-
isl::ctx ctx(filter.get_ctx());
280-
std::unique_ptr<ScheduleTree> res(new ScheduleTree(ctx));
281-
res->elem_ = std::unique_ptr<ScheduleTreeElemFilter>(
282-
new ScheduleTreeElemFilter(filter));
268+
auto res = std::unique_ptr<ScheduleTree>(new ScheduleTreeElemFilter(filter));
283269
res->type_ = detail::ScheduleTreeType::Filter;
284270
res->appendChildren(std::move(children));
285271
return res;
@@ -299,22 +285,17 @@ std::unique_ptr<ScheduleTree> ScheduleTree::makeMappingUnsafe(
299285
TC_CHECK_EQ(mappedIds.size(), mapping.size())
300286
<< "some id is used more than once in the mapping";
301287
auto ctx = mappedIds[0].get_ctx();
302-
ScheduleTreeUPtr res(new ScheduleTree(
303-
ctx,
304-
std::move(children),
305-
ScheduleTreeType::Mapping,
306-
std::unique_ptr<ScheduleTreeElemMapping>(
307-
new ScheduleTreeElemMapping(mapping))));
288+
auto res =
289+
std::unique_ptr<ScheduleTree>(new ScheduleTreeElemMapping(mapping));
290+
res->appendChildren(std::move(children));
308291
return res;
309292
}
310293

311294
std::unique_ptr<ScheduleTree> ScheduleTree::makeExtension(
312295
isl::union_map extension,
313296
std::vector<ScheduleTreeUPtr>&& children) {
314-
isl::ctx ctx(extension.get_ctx());
315-
ScheduleTreeUPtr res(new ScheduleTree(ctx));
316-
res->elem_ = std::unique_ptr<ScheduleTreeElemExtension>(
317-
new ScheduleTreeElemExtension(extension));
297+
auto res =
298+
std::unique_ptr<ScheduleTree>(new ScheduleTreeElemExtension(extension));
318299
res->type_ = detail::ScheduleTreeType::Extension;
319300
res->appendChildren(std::move(children));
320301
return res;
@@ -323,9 +304,8 @@ std::unique_ptr<ScheduleTree> ScheduleTree::makeExtension(
323304
std::unique_ptr<ScheduleTree> ScheduleTree::makeThreadSpecificMarker(
324305
isl::ctx ctx,
325306
std::vector<ScheduleTreeUPtr>&& children) {
326-
ScheduleTreeUPtr res(new ScheduleTree(ctx));
327-
res->elem_ = std::unique_ptr<ScheduleTreeElemThreadSpecificMarker>(
328-
new ScheduleTreeElemThreadSpecificMarker());
307+
auto res =
308+
std::unique_ptr<ScheduleTree>(new ScheduleTreeElemThreadSpecificMarker());
329309
res->type_ = detail::ScheduleTreeType::ThreadSpecificMarker;
330310
res->appendChildren(std::move(children));
331311
return res;
@@ -411,7 +391,7 @@ bool ScheduleTree::operator==(const ScheduleTree& other) const {
411391
if (children_.size() != other.children_.size()) {
412392
return false;
413393
}
414-
if (!elemEquals(elem_.get(), other.elem_.get(), type_)) {
394+
if (!elemEquals(this, &other, type_)) {
415395
return false;
416396
}
417397
TC_CHECK(!other.elemAs<ScheduleTreeElemSet>())

tc/core/polyhedral/schedule_tree.h

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
#include <vector>
2222

2323
#include "tc/core/check.h"
24+
#include "tc/core/polyhedral/mapping_types.h"
2425
#include "tc/core/polyhedral/options.h"
25-
#include "tc/core/polyhedral/schedule_tree_elem.h"
2626
#include "tc/core/utils/vararg.h"
2727
#include "tc/external/isl.h"
2828

@@ -36,6 +36,23 @@ namespace detail {
3636
// ScheduleTree, convertible to and from isl::schedule.
3737
//
3838
struct ScheduleTree;
39+
struct ScheduleTreeElemSet;
40+
struct ScheduleTreeElemSequence;
41+
struct ScheduleTreeElemMapping;
42+
43+
enum class ScheduleTreeType {
44+
None,
45+
Band,
46+
Context,
47+
Domain,
48+
Extension,
49+
Filter,
50+
Sequence,
51+
Set,
52+
Mapping,
53+
ThreadSpecificMarker,
54+
Any,
55+
};
3956

4057
} // namespace detail
4158

@@ -116,19 +133,20 @@ struct ScheduleTree {
116133

117134
private:
118135
ScheduleTree() = delete;
136+
137+
protected:
119138
ScheduleTree(
120139
isl::ctx ctx,
121140
std::vector<ScheduleTreeUPtr>&& children,
122-
detail::ScheduleTreeType type,
123-
std::unique_ptr<ScheduleTreeElemBase>&& elem)
124-
: ctx_(ctx), type_(type), elem_(std::move(elem)) {
141+
detail::ScheduleTreeType type)
142+
: ctx_(ctx), type_(type) {
125143
appendChildren(std::move(children));
126144
}
127145

128146
// Copy constructor for internal use only.
129-
// Note that this does not copy the underlying elem_.
147+
// Note that this does not account for a specific subclass of ScheduleTree.
130148
// All callers should use makeScheduleTree(const ScheduleTree&) instead,
131-
// which copies the underlying elem_ as well as keeps the memory
149+
// which dispatches the copying to subclasses as well as keeps the memory
132150
// management scheme consistent.
133151
ScheduleTree(const ScheduleTree& st);
134152

@@ -394,23 +412,19 @@ struct ScheduleTree {
394412
static ScheduleTreeUPtr
395413
fromList(detail::ScheduleTreeType type, Arg&& arg, Args&&... args) {
396414
static_assert(
397-
std::is_base_of<ScheduleTreeElemBase, T>::value,
398-
"Can only construct elements derived from ScheduleTreeElemBase");
415+
std::is_base_of<ScheduleTree, T>::value,
416+
"Can only construct elements derived from ScheduleTree");
399417
static_assert(
400418
std::is_same<
401419
typename std::remove_reference<Arg>::type,
402420
ScheduleTreeUPtr>::value,
403421
"Arguments must be rvalue references to ScheduleTreeUPtr");
404422

405-
auto ctx = arg->ctx_;
406-
std::vector<ScheduleTreeUPtr> children =
407-
vectorFromArgs(std::forward<Arg>(arg), std::forward<Args>(args)...);
408-
409-
auto res = ScheduleTreeUPtr(new ScheduleTree(
410-
ctx,
411-
std::move(children),
412-
type,
413-
std::unique_ptr<ScheduleTreeElemBase>(new T)));
423+
auto ctx = arg->ctx_; // FIXME: pass this to the constructor of T
424+
// when possible
425+
auto res = ScheduleTreeUPtr(new T);
426+
res->appendChildren(
427+
vectorFromArgs(std::forward<Arg>(arg), std::forward<Args>(args)...));
414428

415429
if (type == detail::ScheduleTreeType::Sequence ||
416430
type == detail::ScheduleTreeType::Set) {
@@ -453,7 +467,7 @@ struct ScheduleTree {
453467
const ScheduleTree* tree,
454468
detail::ScheduleTreeType type);
455469

456-
// View elem_ as the specified type.
470+
// View this tree node as the specified type.
457471
// Returns nullptr if this is not the proper type.
458472
// Inline impl for now, does not justify an extra -inl.h file
459473
template <typename T>
@@ -464,13 +478,12 @@ struct ScheduleTree {
464478
template <typename T>
465479
const T* elemAs() const {
466480
static_assert(
467-
std::is_base_of<ScheduleTreeElemBase, T>::value,
468-
"Must call with a class derived from ScheduleTreeElemBase");
481+
std::is_base_of<ScheduleTree, T>::value,
482+
"Must call with a class derived from ScheduleTree");
469483
if (type_ != T::NodeType) {
470484
return nullptr;
471485
}
472-
return static_cast<const T*>(
473-
const_cast<const ScheduleTreeElemBase*>(elem_.get()));
486+
return static_cast<const T*>(this);
474487
}
475488

476489
virtual ScheduleTreeType type() const {
@@ -494,7 +507,6 @@ struct ScheduleTree {
494507

495508
public:
496509
detail::ScheduleTreeType type_{detail::ScheduleTreeType::None};
497-
std::unique_ptr<ScheduleTreeElemBase> elem_{nullptr};
498510
};
499511

500512
} // namespace detail

tc/core/polyhedral/schedule_tree_elem.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ namespace detail {
4040
std::unique_ptr<ScheduleTreeElemBand> ScheduleTreeElemBand::fromMultiUnionPwAff(
4141
isl::multi_union_pw_aff mupa) {
4242
isl::ctx ctx(mupa.get_ctx());
43-
std::unique_ptr<ScheduleTreeElemBand> band(new ScheduleTreeElemBand);
43+
std::unique_ptr<ScheduleTreeElemBand> band(new ScheduleTreeElemBand(ctx));
4444
band->mupa_ = mupa.floor();
4545
size_t n = band->mupa_.size();
4646
band->coincident_ = vector<bool>(n, false);
@@ -202,8 +202,8 @@ bool ScheduleTreeElemSet::operator==(const ScheduleTreeElemSet& other) const {
202202
}
203203

204204
bool elemEquals(
205-
const ScheduleTreeElemBase* e1,
206-
const ScheduleTreeElemBase* e2,
205+
const ScheduleTree* e1,
206+
const ScheduleTree* e2,
207207
detail::ScheduleTreeType type) {
208208
#define ELEM_EQUALS_CASE(CLASS) \
209209
else if (type == CLASS::NodeType) { \
@@ -222,7 +222,7 @@ bool elemEquals(
222222
ELEM_EQUALS_CASE(ScheduleTreeElemSequence)
223223
ELEM_EQUALS_CASE(ScheduleTreeElemSet)
224224
else {
225-
LOG(FATAL) << "NYI: ScheduleTreeElemBase::operator== for type: "
225+
LOG(FATAL) << "NYI: ScheduleTree::operator== for type: "
226226
<< static_cast<int>(type);
227227
}
228228

0 commit comments

Comments
 (0)