Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit b93d260

Browse files
Merge pull request #522 from facebookresearch/pr/clean_up
assorted clean-ups
2 parents 41279c8 + c4082d6 commit b93d260

File tree

6 files changed

+13
-17
lines changed

6 files changed

+13
-17
lines changed

tc/core/polyhedral/cuda/codegen.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -528,11 +528,10 @@ isl::multi_aff makeMultiAffAccess(
528528

529529
namespace {
530530
bool is_identifier_or_nonnegative_integer(isl::ast_expr expr) {
531-
if (isl_ast_expr_get_type(expr.get()) == isl_ast_expr_id)
532-
return true;
533-
if (isl_ast_expr_get_type(expr.get()) != isl_ast_expr_int)
534-
return false;
535-
return isl::manage(isl_ast_expr_get_val(expr.get())).is_nonneg();
531+
if (auto intExpr = expr.as<isl::ast_expr_int>()) {
532+
return intExpr.get_val().is_nonneg();
533+
}
534+
return !expr.as<isl::ast_expr_id>().is_null();
536535
}
537536
} // namespace
538537

tc/core/polyhedral/memory_promotion.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ std::unique_ptr<TensorReferenceGroup> TensorReferenceGroup::makeSingleton(
115115
return group;
116116
}
117117

118-
isl::map TensorReferenceGroup::approximateFootprint() const {
118+
isl::map TensorReferenceGroup::approximateScopedAccesses() const {
119119
auto scopedDomain = scopedAccesses().domain();
120120
auto space = approximation.box.get_space();
121121
auto accessed = isl::map::universe(space).intersect_domain(scopedDomain);
@@ -271,8 +271,8 @@ void joinOverlappingWrites(
271271
if (g1->isReadOnly() && g2->isReadOnly()) {
272272
continue;
273273
}
274-
if (g1->approximateFootprint()
275-
.intersect(g2->approximateFootprint())
274+
if (g1->approximateScopedAccesses()
275+
.intersect(g2->approximateScopedAccesses())
276276
.is_empty()) {
277277
continue;
278278
}
@@ -518,7 +518,7 @@ ScheduleTree* insertCopiesUnder(
518518
auto arrayId =
519519
promotionSpace.domain().unwrap().get_tuple_id(isl::dim_type::out);
520520
auto approximatedRead =
521-
group.approximateFootprint().intersect_range(tensorElements).wrap();
521+
group.approximateScopedAccesses().intersect_range(tensorElements).wrap();
522522
approximatedRead = approximatedRead.product(promotedFootprint);
523523
auto readExtension = extension.intersect_range(approximatedRead)
524524
.set_tuple_id(isl::dim_type::out, readId);

tc/core/polyhedral/memory_promotion.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ class TensorReferenceGroup {
143143

144144
// Rectangular overapproximation of the set of tensor elements accessed below
145145
// and relative to the scoping point.
146-
isl::map approximateFootprint() const;
146+
isl::map approximateScopedAccesses() const;
147147

148148
isl::multi_aff promotion() const;
149149
isl::set promotedFootprint() const;

tc/core/polyhedral/schedule_transforms.cc

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ ScheduleTree* bandTile(
363363
return st;
364364
}
365365
auto& band = *eb;
366-
TC_CHECK(band.permutable_) << "Can't tile an non-permutable band" << band;
366+
TC_CHECK(band.permutable_) << "Can't tile a non-permutable band" << band;
367367

368368
auto ts = tileSizes;
369369
if (band.nMember() > ts.size()) {
@@ -381,9 +381,7 @@ ScheduleTree* bandTile(
381381
// Create a child, copy of st before outer tiling
382382
ScheduleTreeUPtr childUPtr = ScheduleTree::makeScheduleTree(*st);
383383

384-
for (size_t i = 0;
385-
i < std::min(static_cast<size_t>(band.nMember()), ts.size());
386-
++i) {
384+
for (size_t i = 0; i < band.nMember(); ++i) {
387385
auto upa = band.mupa_.get_union_pw_aff(i);
388386
if (ts[i]) {
389387
upa = upa.scale_down(isl::val(st->ctx_, ts[i])).floor();
@@ -399,7 +397,6 @@ ScheduleTree* bandTile(
399397
auto ebChild = childUPtr->elemAs<ScheduleTreeElemBand>();
400398
TC_CHECK(ebChild) << "Not a band: " << *childUPtr;
401399
auto& childBand = *ebChild;
402-
// No need for isl_schedule_band_point, it's almost done
403400
if (tileOptions & TileOptions::ShiftPointLoops) {
404401
auto mupa = band.mupa_;
405402
if (!(tileOptions & TileOptions::ScaleTileLoops)) {

tc/core/polyhedral/schedule_tree.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ bool ScheduleTree::operator==(const ScheduleTree& other) const {
361361
return false;
362362
}
363363
TC_CHECK(!other.elemAs<ScheduleTreeElemSet>())
364-
<< "NYI: isl_node_type::set comparison";
364+
<< "NYI: ScheduleTreeType::Set comparison";
365365
for (size_t i = 0; i < children_.size(); ++i) {
366366
if (*children_[i] != *other.children_[i]) {
367367
return false;

test/test_cuda_mapper_memory_promotion.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def fun(float(N, M) A, float(N, M) B) -> (C) {
305305
EXPECT_EQ(
306306
oneGroup->approximation.size(1),
307307
isl::val(ctx, std::min(tile2, problemSize2)));
308-
auto footprint = tileZero.apply(oneGroup->approximateFootprint());
308+
auto footprint = tileZero.apply(oneGroup->approximateScopedAccesses());
309309
size_t np = npoints(footprint);
310310
EXPECT_EQ(
311311
np, std::min(tile1, problemSize1) * std::min(tile2, problemSize2));

0 commit comments

Comments
 (0)