Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 9b7ecbb

Browse files
Merge pull request #527 from facebookresearch/pr/single_instance
ScheduleTree::bandTile: drop special casing 0D tile sizes
2 parents 031db14 + 7bb4101 commit 9b7ecbb

File tree

6 files changed

+34
-8
lines changed

6 files changed

+34
-8
lines changed

tc/core/polyhedral/cuda/mapped_scop.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,16 @@ class MappedScop {
7272
public:
7373
static inline std::unique_ptr<MappedScop> makeOneBlockOneThread(
7474
std::unique_ptr<Scop>&& scop) {
75-
return std::unique_ptr<MappedScop>(new MappedScop(
75+
auto mscop = std::unique_ptr<MappedScop>(new MappedScop(
7676
std::move(scop), ::tc::Grid{1, 1, 1}, ::tc::Block{1, 1, 1}, 1, false));
77+
auto band = mscop->scop_->obtainOuterBand();
78+
mscop->mapBlocksForward(band, 0);
79+
mscop->mapThreadsBackward(band);
80+
return mscop;
7781
}
82+
// The MappedScop returned by this method does not satisfy the invariant
83+
// of having a mapping to blocks and threads. It is up to the caller
84+
// to insert these mappings.
7885
static inline std::unique_ptr<MappedScop> makeMappedScop(
7986
std::unique_ptr<Scop>&& scop,
8087
::tc::Grid grid,

tc/core/polyhedral/cuda/tighten_launch_bounds.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ size_t maxValue(const Scop& scop, const MappingIdType& id) {
8888
LOG_IF(WARNING, min > 0)
8989
<< "Opportunity for tightening launch bounds with shifting -> min:"
9090
<< min;
91-
TC_CHECK(max < sizetMax) << "missing mapping to " << id << *root;
91+
TC_CHECK(max < sizetMax) << "missing mapping to " << id << "\n" << *root;
92+
TC_CHECK(min < sizetMax) << "missing mapping to " << id << " type\n" << *root;
9293
// Inclusive range needs + 1 to translate to sizes
9394
return max + 1;
9495
}

tc/core/polyhedral/schedule_transforms.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -359,9 +359,6 @@ ScheduleTree* bandTile(
359359
auto eb = st->elemAs<ScheduleTreeElemBand>();
360360
TC_CHECK(eb) << "Not a band: " << *st;
361361

362-
if (tileSizes.size() == 0) {
363-
return st;
364-
}
365362
auto& band = *eb;
366363
TC_CHECK(band.permutable_) << "Can't tile a non-permutable band" << band;
367364

tc/core/polyhedral/scop.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,8 @@ detail::ScheduleTree* setPermutable(detail::ScheduleTree* tree) {
423423
return tree;
424424
}
425425

426+
} // namespace
427+
426428
/*
427429
* Return the outermost band in the schedule tree with the given root.
428430
* If there is no single outermost band, then insert a (permutable)
@@ -431,7 +433,8 @@ detail::ScheduleTree* setPermutable(detail::ScheduleTree* tree) {
431433
* insert the band in the leaf. If branching is encountered, then
432434
* insert the band above the branching.
433435
*/
434-
detail::ScheduleTree* obtainOuterBand(detail::ScheduleTree* root) {
436+
detail::ScheduleTree* Scop::obtainOuterBand() {
437+
auto root = scheduleRoot();
435438
auto tree = root;
436439
while (!tree->elemAs<ScheduleTreeElemBand>()) {
437440
auto n = tree->numChildren();
@@ -449,11 +452,10 @@ detail::ScheduleTree* obtainOuterBand(detail::ScheduleTree* root) {
449452
}
450453
return tree;
451454
}
452-
} // namespace
453455

454456
detail::ScheduleTree* Scop::tileOuterBand(const TilingView& tileSizes) {
455457
using namespace tc::polyhedral::detail;
456-
auto band = obtainOuterBand(scheduleRoot());
458+
auto band = obtainOuterBand();
457459
auto bandNode = band->elemAs<ScheduleTreeElemBand>();
458460
std::vector<size_t> sizes = tileSizes.extractVector();
459461
if (bandNode->nMember() < sizes.size()) {

tc/core/polyhedral/scop.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,10 @@ struct Scop {
394394
static std::unique_ptr<Scop> makeScheduled(
395395
const Scop& scop,
396396
const SchedulerOptionsView& schedulerOptions);
397+
// Return the outermost band in the schedule tree with the given root.
398+
// If there is no single outermost band, then insert a (permutable)
399+
// zero-dimensional band and return that.
400+
detail::ScheduleTree* obtainOuterBand();
397401
// Tile the outermost band.
398402
// Splits the band into tile loop band and point loop band where point loops
399403
// have fixed trip counts specified in "tiling", and returns a pointer to the

test/test_cuda_mapper.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1151,6 +1151,21 @@ def local_sparse_convolution(float(N, C, H, W) I, float(O, KC, KH, KW) W1) -> (O
11511151
}
11521152
}
11531153

1154+
/*
1155+
* Check that a TC with a single instance gets mapped properly.
1156+
* tightenLaunchBounds (called during codegen) will complain
1157+
* if it is not.
1158+
*/
1159+
TEST_F(PolyhedralMapperTest, SingleInstance) {
1160+
string tc = R"TC(
1161+
def f(float(N) I) -> (a)
1162+
{
1163+
a = 0
1164+
}
1165+
)TC";
1166+
codegenMapped(tc, DefaultOptions());
1167+
}
1168+
11541169
int main(int argc, char** argv) {
11551170
::testing::InitGoogleTest(&argc, argv);
11561171
::gflags::ParseCommandLineFlags(&argc, &argv, true);

0 commit comments

Comments
 (0)