From 5df21f85fe0c431efb2b31a9f11ae3e2dbf668f1 Mon Sep 17 00:00:00 2001
From: Sven Verdoolaege
Date: Thu, 17 May 2018 11:56:45 +0200
Subject: [PATCH 1/4] MappedScop::findBestSync: avoid redundant computations

isl does not currently perform any operations lazily, so a potentially
large product map would get created even if only a very small part of it
will actually get used.  Intersect domain and range directly instead.
---
 tc/core/polyhedral/cuda/mapped_scop.cc | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tc/core/polyhedral/cuda/mapped_scop.cc b/tc/core/polyhedral/cuda/mapped_scop.cc
index 955190ade..ebcbac989 100644
--- a/tc/core/polyhedral/cuda/mapped_scop.cc
+++ b/tc/core/polyhedral/cuda/mapped_scop.cc
@@ -550,9 +550,9 @@ Scop::SyncLevel MappedScop::findBestSync(
   auto activePoints2 = activeDomainPointsBelow(stRoot, st2);
 
   // The dependences between the two schedule trees
-  auto dependences =
-      isl::union_map::from_domain_and_range(activePoints1, activePoints2);
-  dependences = dependences.intersect(scop_->dependences);
+  auto dependences = scop_->dependences;
+  dependences = dependences.intersect_domain(activePoints1);
+  dependences = dependences.intersect_range(activePoints2);
   if (dependences.is_empty()) {
     return Scop::SyncLevel::None;
   }
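
As an illustration (not part of the patch series), the two formulations can be
contrasted with the isl C++ wrappers that already appear in the diff above;
the helper names below are made up for this sketch.

    // Assumes the TC isl C++ bindings ("tc/external/isl.h").
    // Both helpers restrict "dependences" to pairs with a source in "from"
    // and a sink in "to".  The first builds the full cross product
    // "from -> to", which isl computes eagerly; the second intersects the
    // domain and range directly and never materializes that product.
    isl::union_map restrictViaProduct(
        isl::union_map dependences,
        isl::union_set from,
        isl::union_set to) {
      auto product = isl::union_map::from_domain_and_range(from, to);
      return dependences.intersect(product);
    }

    isl::union_map restrictDirectly(
        isl::union_map dependences,
        isl::union_set from,
        isl::union_set to) {
      return dependences.intersect_domain(from).intersect_range(to);
    }

Both helpers compute the same relation; the patch switches findBestSync from
the first form to the second.
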
From a630d0456891446eff3bf17bc1e27d6e868b32ee Mon Sep 17 00:00:00 2001
From: Sven Verdoolaege
Date: Wed, 16 May 2018 17:44:08 +0200
Subject: [PATCH 2/4] ScheduleTreeElemMappingFilter: store mapping between
 identifiers and functions

Store this mapping in addition to the filter derived from the mapping.
The mapping is needed in the next commit.

The mapping is constructed in ScheduleTree::makeMappingFilter and passed
to the ScheduleTreeElemMappingFilter constructor.  Some of the sanity
checks are therefore moved from the constructor to
ScheduleTree::makeMappingFilter.
---
 .../cuda/memory_promotion_heuristic.cc  |  4 +-
 tc/core/polyhedral/schedule_print.cc    |  4 +-
 tc/core/polyhedral/schedule_tree-inl.h  | 14 ++++---
 tc/core/polyhedral/schedule_tree_elem.h | 40 ++++++++-----------
 4 files changed, 30 insertions(+), 32 deletions(-)

diff --git a/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc b/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc
index c51e09a96..1c174e0ae 100644
--- a/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc
+++ b/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc
@@ -46,8 +46,8 @@ bool isThreadMapping(const detail::ScheduleTree* tree) {
   using namespace detail;
 
   if (auto filterNode = tree->elemAs<ScheduleTreeElemMappingFilter>()) {
-    for (auto id : filterNode->mappingIds) {
-      if (isThreadId(id)) {
+    for (auto& kvp : filterNode->mapping) {
+      if (isThreadId(kvp.first)) {
         return true;
       }
     }
diff --git a/tc/core/polyhedral/schedule_print.cc b/tc/core/polyhedral/schedule_print.cc
index b1371649f..d12b43541 100644
--- a/tc/core/polyhedral/schedule_print.cc
+++ b/tc/core/polyhedral/schedule_print.cc
@@ -202,8 +202,8 @@ std::ostream& ScheduleTreeElemFilter::write(std::ostream& os) const {
 std::ostream& ScheduleTreeElemMappingFilter::write(std::ostream& os) const {
   WS w;
   os << w.tab() << "mapping_filter(ids(";
-  for (const auto& id : mappingIds) {
-    os << id << ", ";
+  for (auto& kvp : mapping) {
+    os << kvp.first << ", ";
   }
   os << ")";
   for (const auto& u : filter_.get_set_list()) {
diff --git a/tc/core/polyhedral/schedule_tree-inl.h b/tc/core/polyhedral/schedule_tree-inl.h
index b97a2ff1d..f3e4e614d 100644
--- a/tc/core/polyhedral/schedule_tree-inl.h
+++ b/tc/core/polyhedral/schedule_tree-inl.h
@@ -23,15 +23,19 @@ inline ScheduleTreeUPtr ScheduleTree::makeMappingFilter(
     const std::vector<mapping::MappingId>& mappedIds,
     isl::union_pw_aff_list mappedAffs,
     std::vector<ScheduleTreeUPtr>&& children) {
-  std::vector<mapping::MappingId> ids;
-  for (auto id : mappedIds) {
-    ids.push_back(id);
+  CHECK_EQ(mappedIds.size(), static_cast<size_t>(mappedAffs.n()))
+      << "expected as many mapped ids as affs";
+  ScheduleTreeElemMappingFilter::Mapping mapping;
+  for (size_t i = 0, n = mappedAffs.n(); i < n; ++i) {
+    mapping.emplace(mappedIds.at(i), mappedAffs.get(i));
   }
-  CHECK_GE(ids.size(), 1u) << "empty mapping";
+  CHECK_GE(mapping.size(), 1u) << "empty mapping";
+  CHECK_EQ(mappedIds.size(), mapping.size())
+      << "some id is used more than once in the mapping";
   auto ctx = mappedIds[0].get_ctx();
   ScheduleTreeUPtr res(new ScheduleTree(ctx));
   res->elem_ = std::unique_ptr<ScheduleTreeElemMappingFilter>(
-      new ScheduleTreeElemMappingFilter(ids, mappedAffs));
+      new ScheduleTreeElemMappingFilter(mapping));
   res->type_ = ScheduleTreeType::MappingFilter;
   res->appendChildren(std::move(children));
   return res;
diff --git a/tc/core/polyhedral/schedule_tree_elem.h b/tc/core/polyhedral/schedule_tree_elem.h
index 2f0e2921f..f1ab3153b 100644
--- a/tc/core/polyhedral/schedule_tree_elem.h
+++ b/tc/core/polyhedral/schedule_tree_elem.h
@@ -17,7 +17,7 @@
 
 #include <memory>
 #include <string>
-#include <unordered_set>
+#include <unordered_map>
 #include <vector>
 
 #include "tc/external/isl.h"
@@ -139,34 +139,29 @@ struct ScheduleTreeElemFilter : public ScheduleTreeElemBase {
 };
 
 struct ScheduleTreeElemMappingFilter : public ScheduleTreeElemFilter {
+  using Mapping = std::unordered_map<
+      mapping::MappingId,
+      isl::union_pw_aff,
+      typename mapping::MappingId::Hash>;
   static constexpr std::initializer_list<detail::ScheduleTreeType>
       NodeDerivedTypes{detail::ScheduleTreeType::None};
   static constexpr detail::ScheduleTreeType NodeType =
       detail::ScheduleTreeType::MappingFilter;
   ScheduleTreeElemMappingFilter() = delete;
   ScheduleTreeElemMappingFilter(const ScheduleTreeElemMappingFilter& eb)
-      : ScheduleTreeElemFilter(eb.filter_), mappingIds(eb.mappingIds) {}
-  ScheduleTreeElemMappingFilter(
-      const std::vector<mapping::MappingId>& mappedIds,
-      isl::union_pw_aff_list mappedAffs)
-      : ScheduleTreeElemFilter(isl::union_set()),
-        mappingIds(mappedIds.begin(), mappedIds.end()) {
-    // Check that ids are unique.
-    CHECK_EQ(mappedIds.size(), mappingIds.size())
-        << "some id is used more than once in the mapping";
+      : ScheduleTreeElemFilter(eb.filter_), mapping(eb.mapping) {}
+  ScheduleTreeElemMappingFilter(const Mapping& mapping)
+      : ScheduleTreeElemFilter(isl::union_set()), mapping(mapping) {
+    CHECK_GT(mapping.size(), 0u) << "empty mapping filter";
 
-    CHECK_EQ(mappedIds.size(), static_cast<size_t>(mappedAffs.n()))
-        << "expected as many mapped ids as affs";
-    CHECK_GT(mappedIds.size(), 0u) << "empty mapping filter";
-
-    auto domain = mappedAffs.get(0).domain();
-    for (size_t i = 1, n = mappedAffs.n(); i < n; ++i) {
-      CHECK(domain.is_equal(mappedAffs.get(i).domain()));
+    auto domain = mapping.cbegin()->second.domain();
+    for (auto& kvp : mapping) {
+      CHECK(domain.is_equal(kvp.second.domain()));
     }
     filter_ = domain.universe();
-    for (size_t i = 0, n = mappedAffs.n(); i < n; ++i) {
-      auto upa = mappedAffs.get(i);
-      auto id = mappedIds.at(i);
+    for (auto& kvp : mapping) {
+      auto upa = kvp.second;
+      auto id = kvp.first;
       // Create mapping filter by equating the
       // parameter mappedIds[i] to the "i"-th affine function.
       upa = upa.sub(isl::union_pw_aff::param_on_domain(domain.universe(), id));
@@ -183,9 +178,8 @@ struct ScheduleTreeElemMappingFilter : public ScheduleTreeElemFilter {
     return NodeType;
   }
 
-  const std::
-      unordered_set<mapping::MappingId, typename mapping::MappingId::Hash>
-          mappingIds;
+  // Mapping from identifiers to affine functions on domain elements.
+  const Mapping mapping;
 };
 
 struct ScheduleTreeElemSequence : public ScheduleTreeElemBase {
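
As an illustration (not part of the patch series), the new "mapping" member can
be consulted as sketched below, reusing only calls that appear in these
patches; the helper name and the "node" parameter are made up for the sketch,
which is assumed to live inside namespace tc::polyhedral.

    // Return the affine function that a mapping filter node associates with
    // the i-th thread identifier.  Mapping is an unordered_map keyed by
    // mapping::MappingId (with its custom Hash functor), so the lookup is a
    // plain .at() on the identifier.
    isl::union_pw_aff threadAffOf(
        const detail::ScheduleTreeElemMappingFilter* node,
        size_t i) {
      auto threadId = mapping::ThreadId::makeId(i);
      // at() throws if this node does not map the requested thread identifier.
      return node->mapping.at(threadId);
    }

This is essentially what extractDomainToThread in the next patch does for each
mapping filter node it visits.
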
From e3323edbe1873ee4388eaee58c3fa5feb236d850 Mon Sep 17 00:00:00 2001
From: Sven Verdoolaege
Date: Thu, 17 May 2018 12:07:15 +0200
Subject: [PATCH 3/4] MappedScop::findBestSync: use proper mapping to threads
 and warps

In particular, use mappings to threads and warps to determine
whether all dependences are within a thread or warp.
Do so instead of considering pairs of thread identifier parameters.

Duplicating thread identifier parameters is confusing at best
and relies on renaming parameters, which is frowned upon
and which will not be exported in the mainline isl C++ interface.
Duplicate parameters are confusing because there is only one value
for any given thread identifier at a given point in the execution.
Furthermore, the filter assigning the thread mapping to parameters
is only relevant during AST generation.  Prior to that, it is more
natural to think in terms of the thread mapping itself.

This is what this commit does.  It is not only simpler, but also shorter.

It is not entirely clear why the intersection with the context is needed,
but this is what the original code does and it is preserved by this commit.
---
 tc/core/polyhedral/cuda/mapped_scop.cc | 201 +++++++++----------------
 tc/core/polyhedral/cuda/mapped_scop.h  |   6 +-
 2 files changed, 78 insertions(+), 129 deletions(-)

diff --git a/tc/core/polyhedral/cuda/mapped_scop.cc b/tc/core/polyhedral/cuda/mapped_scop.cc
index ebcbac989..11211c0e6 100644
--- a/tc/core/polyhedral/cuda/mapped_scop.cc
+++ b/tc/core/polyhedral/cuda/mapped_scop.cc
@@ -431,119 +431,84 @@ bool hasOuterSequentialMember(
   return false;
 }
 
-// Intersect the union set with all the mapping
-// filters params in the given schedule tree
-isl::union_set intersectMappingFilterParams(
-    detail::ScheduleTree* st,
-    isl::union_set us) {
-  if (auto filter = st->elemAsBase<ScheduleTreeElemFilter>()) {
-    us = us.intersect(filter->filter_);
-  }
+// Name of the space of threads inside a block
+constexpr auto kBlock = "block";
+// Name of the space of warps
+constexpr auto kWarp = "warp";
 
-  auto children = st->children();
-  auto nChildren = children.size();
-  if (nChildren == 1) {
-    us = intersectMappingFilterParams(children[0], us);
-  } else if (nChildren > 1) {
-    auto usParent = us;
-    us = intersectMappingFilterParams(children[0], us);
-    for (size_t i = 1; i < nChildren; ++i) {
-      us = us.unite(intersectMappingFilterParams(children[i], usParent));
-    }
-  }
-
-  return us;
-}
+/*
+ * Extract a mapping from the domain elements active at "tree"
+ * to the thread identifiers, where all branches in "tree"
+ * are assumed to have been mapped to thread identifiers.
+ * "nThread" is the number of thread identifiers.
+ * The result lives in a space of the form block[x, ...].
+ */
+isl::multi_union_pw_aff extractDomainToThread(
+    const detail::ScheduleTree* tree,
+    size_t nThread) {
+  using namespace polyhedral::detail;
 
-// Change the name of the isl ids tied to threads and blocks
-// by adding a suffix
-isl::union_set modifyMappingNames(
-    isl::union_set set,
-    const std::string suffix) {
-  USING_MAPPING_SHORT_NAMES(BX, BY, BZ, TX, TY, TZ);
-  std::unordered_set<mapping::MappingId, mapping::MappingId::Hash> identifiers{
-      BX, BY, BZ, TX, TY, TZ};
-
-  auto space = set.get_space();
-  for (auto id : identifiers) {
-    auto name = id.get_name();
-    auto dim = space.find_dim_by_name(isl::dim_type::param, id.get_name());
-    CHECK_LE(0, dim);
-    space = space.set_dim_name(isl::dim_type::param, dim, name + suffix);
-  }
-  auto newSet = isl::union_set::empty(space);
-  set.foreach_set([&newSet, &identifiers, &suffix](isl::set setInFun) {
-    for (auto id : identifiers) {
-      auto name = id.get_name();
-      auto dim =
-          setInFun.get_space().find_dim_by_name(isl::dim_type::param, name);
-      CHECK_LE(0, dim);
-      setInFun =
-          setInFun.set_dim_name(isl::dim_type::param, dim, name + suffix);
+  auto space = isl::space(tree->ctx_, 0);
+  auto empty = isl::union_set::empty(space);
+  auto id = isl::id(tree->ctx_, kBlock);
+  space = space.named_set_from_params_id(id, nThread);
+  auto zero = isl::multi_val::zero(space);
+  auto domainToThread = isl::multi_union_pw_aff(empty, zero);
+
+  for (auto mapping : tree->collect(tree, ScheduleTreeType::MappingFilter)) {
+    auto mappingNode = mapping->elemAs<ScheduleTreeElemMappingFilter>();
+    auto list = isl::union_pw_aff_list(tree->ctx_, nThread);
+    for (size_t i = 0; i < nThread; ++i) {
+      auto threadId = mapping::ThreadId::makeId(i);
+      auto threadMap = mappingNode->mapping.at(threadId);
+      list = list.add(threadMap);
     }
-    newSet = newSet.unite(setInFun);
-  });
-  return newSet;
-}
-
-// Get the formula computing the linearized index of a thread in a block.
-isl::aff getLinearizedThreadIdxFormula(
-    isl::space space,
-    const Block& block,
-    const std::string& suffix = "") {
-  USING_MAPPING_SHORT_NAMES(BX, BY, BZ, TX, TY, TZ);
-  std::vector<std::pair<isl::id, unsigned long>> mappingIds{
-      {TX, TX.mappingSize(block)},
-      {TY, TY.mappingSize(block)},
-      {TZ, TZ.mappingSize(block)}};
-
-  isl::aff formula = isl::aff(isl::local_space(space));
-
-  for (int i = (int)mappingIds.size() - 1; i >= 0; --i) {
-    auto name = mappingIds[i].first.to_str();
-    auto dim = space.find_dim_by_name(isl::dim_type::param, name + suffix);
-    CHECK_LE(0, dim);
-    auto id = space.get_dim_id(isl::dim_type::param, dim);
-    isl::aff aff(isl::aff::param_on_domain_space(space, id));
-    formula = formula * mappingIds[i].second + aff;
+    auto nodeToThread = isl::multi_union_pw_aff(space, list);
+    domainToThread = domainToThread.union_add(nodeToThread);
   }
-  return formula;
+  return domainToThread;
 }
 
-// Return the constraints ensuring that the points with parameters
-// [t0,t1,t2] and [t0',t1',t2'] are in the same warp.
-// (where t0 is "t0" + suffix1 and t0' is "t0" + suffix2)
-// if suffix1 is "_1" and suffix2 is "_2", the constraint is in the form
-// ((t0_1 + a * t1_1 + b * t2_1) / warpSize).floor()
-// == ((t0_2 + a' * t1_2 + b' * t1_2) / warpSize).floor()
-// with t0_1 + a * t1_1 + b * t2_1 the linearized formula of the thread index.
-// This function returns a set because it might change in the future,
-// and take into account the blocks.
-isl::set getSameWarpConstraints(
-    isl::space space,
-    const std::string& suffix1,
-    const std::string& suffix2,
+/*
+ * Construct a mapping
+ *
+ *   block[x] -> warp[floor((x)/warpSize)]
+ *   block[x, y] -> warp[floor((x + s_x * (y))/warpSize)]
+ *   block[x, y, z] -> warp[floor((x + s_x * (y + s_y * (z)))/warpSize)]
+ *
+ * uniquely mapping thread identifiers that belong to the same warp
+ * (of size "warpSize") to a warp identifier,
+ * based on the thread sizes s_x, s_y up to s_z in "block".
+ */
+isl::multi_aff constructThreadToWarp(
+    isl::ctx ctx,
    const unsigned warpSize,
    const Block& block) {
-  USING_MAPPING_SHORT_NAMES(BX, BY, BZ, TX, TY, TZ);
-  std::vector<std::pair<isl::id, unsigned long>> mappingIds{
-      {TX, TX.mappingSize(block)},
-      {TY, TY.mappingSize(block)},
-      {TZ, TZ.mappingSize(block)}};
-
-  auto formula1 = getLinearizedThreadIdxFormula(space, block, suffix1);
-  auto formula2 = getLinearizedThreadIdxFormula(space, block, suffix2);
+  auto space = isl::space(ctx, 0);
+  auto id = isl::id(ctx, kBlock);
+  auto blockSpace = space.named_set_from_params_id(id, block.view.size());
+  auto warpSpace = space.named_set_from_params_id(isl::id(ctx, kWarp), 1);
+  auto aff = isl::aff::zero_on_domain(blockSpace);
+
+  auto nThread = block.view.size();
+  auto identity = isl::multi_aff::identity(blockSpace.map_from_set());
+  for (int i = nThread - 1; i >= 0; --i) {
+    aff = aff.scale(isl::val(ctx, block.view[i]));
+    aff = aff.add(identity.get_aff(i));
+  }
 
-  return (
-      isl::aff_set((formula1 / warpSize).floor()) ==
-      (formula2 / warpSize).floor());
+  aff = aff.scale_down(isl::val(ctx, warpSize)).floor();
+  auto mapSpace = blockSpace.product(warpSpace).unwrap();
+  return isl::multi_aff(mapSpace, isl::aff_list(aff));
 }
 } // namespace
 
 Scop::SyncLevel MappedScop::findBestSync(
     detail::ScheduleTree* st1,
-    detail::ScheduleTree* st2) {
+    detail::ScheduleTree* st2,
+    isl::multi_union_pw_aff domainToThread,
+    isl::multi_union_pw_aff domainToWarp) {
   // Active points in the two schedule trees
   auto stRoot = scop_->scheduleRoot();
   auto activePoints1 = activeDomainPointsBelow(stRoot, st1);
@@ -557,41 +522,16 @@ Scop::SyncLevel MappedScop::findBestSync(
     return Scop::SyncLevel::None;
   }
 
-  // The domain and the context of the root schedule tree
-  auto domainAndContext = scop_->domain();
   CHECK_LE(1u, scop_->scheduleRoot()->children().size());
   auto contextSt = scop_->scheduleRoot()->children()[0];
   auto contextElem = contextSt->elemAs<detail::ScheduleTreeElemContext>();
   CHECK(nullptr != contextElem);
-  domainAndContext = domainAndContext.intersect_params(contextElem->context_);
+  dependences = dependences.intersect_params(contextElem->context_);
 
-  // The domain of both schedule trees filtered by mapping filters,
-  // and then modified to have different threads and blocks names.
-  auto domain1 = intersectMappingFilterParams(st1, domainAndContext);
-  auto domain2 = intersectMappingFilterParams(st2, domainAndContext);
-  auto suffix1 = "_1";
-  auto suffix2 = "_2";
-  domain1 = modifyMappingNames(domain1, suffix1);
-  domain2 = modifyMappingNames(domain2, suffix2);
-
-  // The dependences between the two schedule trees
-  // with mapping from threads and blocks
-  auto mappedDependences =
-      isl::union_map::from_domain_and_range(domain1, domain2);
-  mappedDependences = mappedDependences.intersect(dependences);
-
-  auto space = mappedDependences.get_space();
-  auto sameThreadConstraint =
-      getSameWarpConstraints(space, suffix1, suffix2, 1, numThreads);
-  auto sameWarpConstraints =
-      getSameWarpConstraints(space, suffix1, suffix2, 32, numThreads);
-
-  if (mappedDependences ==
-      mappedDependences.intersect_params(sameThreadConstraint)) {
+  if (dependences.is_subset(dependences.eq_at(domainToThread))) {
     return Scop::SyncLevel::None;
-  } else if (
-      mappedDependences ==
-      mappedDependences.intersect_params(sameWarpConstraints)) {
+  }
+  if (dependences.is_subset(dependences.eq_at(domainToWarp))) {
     return Scop::SyncLevel::Warp;
   }
   return Scop::SyncLevel::Block;
@@ -754,6 +694,10 @@ void MappedScop::insertBestSyncInSeq(detail::ScheduleTree* seq) {
 
   auto outer = hasOuterSequentialMember(scop_->scheduleRoot(), seq);
 
+  auto domainToThread = extractDomainToThread(seq, numThreads.view.size());
+  auto threadToWarp = constructThreadToWarp(seq->ctx_, 32, numThreads);
+  auto domainToWarp = domainToThread.apply(threadToWarp);
+
   std::vector<std::vector<int>> bestSync(
       nChildren, std::vector<int>(nChildren + 1));
   // Get the synchronization needed between children[i] and children[i+k]
@@ -765,7 +709,8 @@ void MappedScop::insertBestSyncInSeq(detail::ScheduleTree* seq) {
   for (size_t i = 0; i < nChildren; ++i) {
     for (size_t k = 0; k < nChildren; ++k) {
       auto ik = (i + k) % nChildren;
-      bestSync[i][k] = (int)findBestSync(children[i], children[ik]);
+      bestSync[i][k] = (int)findBestSync(
+          children[i], children[ik], domainToThread, domainToWarp);
     }
   }
 
diff --git a/tc/core/polyhedral/cuda/mapped_scop.h b/tc/core/polyhedral/cuda/mapped_scop.h
index 6b84b704a..169b4f138 100644
--- a/tc/core/polyhedral/cuda/mapped_scop.h
+++ b/tc/core/polyhedral/cuda/mapped_scop.h
@@ -163,9 +163,13 @@ class MappedScop {
   // st1.
   // This function assumes that it is called before block mapping
   // and that st1 and st2 are already mapped to threads.
+  // "domainToThread" and "domainToWarp" map the domain elements
+  // of st1 and st2 to thread and warp identifiers, respectively.
   Scop::SyncLevel findBestSync(
       detail::ScheduleTree* st1,
-      detail::ScheduleTree* st2);
+      detail::ScheduleTree* st2,
+      isl::multi_union_pw_aff domainToThread,
+      isl::multi_union_pw_aff domainToWarp);
 
  public:
   // Find best configuration of synchronizations in a sequence, minimizing
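
As an illustration (not part of the patch series), the isl expression built by
constructThreadToWarp encodes the usual linearization of thread identifiers.
A plain-integer sketch of the same computation, with made-up names, for a
thread block of sizes (sx, sy, sz):

    // Warp index of thread (x, y, z) in a block of sizes (sx, sy, sz),
    // matching block[x, y, z] -> warp[floor((x + sx*(y + sy*(z)))/warpSize)].
    unsigned warpOf(
        unsigned x,
        unsigned y,
        unsigned z,
        unsigned sx,
        unsigned sy,
        unsigned warpSize = 32) {
      unsigned linear = x + sx * (y + sy * z);
      return linear / warpSize; // integer division is floor for non-negatives
    }

With domainToThread and domainToWarp in hand, findBestSync only needs to check
whether every dependence relates domain points mapped to the same thread (no
synchronization), to the same warp (warp-level synchronization), or to neither
(block-level synchronization), which is what the eq_at/is_subset tests above
express.
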
From 6dc1fbca7877804b2984d708a8acfd700ee25158 Mon Sep 17 00:00:00 2001
From: Sven Verdoolaege
Date: Tue, 29 May 2018 14:33:18 +0200
Subject: [PATCH 4/4] test

---
 test/cuda/test_tc_mapper_bugs.cc | 32 ++++++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/test/cuda/test_tc_mapper_bugs.cc b/test/cuda/test_tc_mapper_bugs.cc
index ace9a4de2..cfa812400 100644
--- a/test/cuda/test_tc_mapper_bugs.cc
+++ b/test/cuda/test_tc_mapper_bugs.cc
@@ -24,6 +24,7 @@
 #include "tc/core/cuda/cuda_tc_executor.h"
 #include "tc/core/flags.h"
 #include "tc/core/polyhedral/exceptions.h"
+#include "tc/library/matmul.h"
 
 #include "test_harness_aten_cuda.h"
 
@@ -39,6 +40,37 @@ using namespace tc;
 // "Bug" suffix and fly away in the sun.
 ///////////////////////////////////////////////////////////////////////////////
 
+TEST(A, B) {
+  auto TC = makeMatmulTc();
+  auto options =
+      tc::CudaMappingOptions::makeNaiveMappingOptions()
+          .outerScheduleFusionStrategy(tc::FusionStrategy::Max)
+          .outerScheduleAllowSkewing(false)
+          .outerSchedulePositiveOrthant(true)
+          .intraTileScheduleFusionStrategy(tc::FusionStrategy::Min)
+          .intraTileScheduleAllowSkewing(false)
+          .intraTileSchedulePositiveOrthant(true)
+          .fixParametersBeforeScheduling(false)
+          .tile(56, 32, 4, 14, 16)
+          .unroll(16)
+          .tileImperfectlyNested(false)
+          .matchLibraryCalls(false)
+          .mapToThreads(4, 128)
+          .mapToBlocks(1, 32, 32)
+          .useSharedMemory(false)
+          .usePrivateMemory(true)
+          .unrollCopyShared(false)
+          .useReadOnlyCache(false);
+  uint32_t N = 100, K = 400, M = 500;
+  at::Tensor A = at::CUDA(at::kFloat).rand({N, K});
+  at::Tensor B = at::CUDA(at::kFloat).rand({K, M});
+  std::vector<at::Tensor> inputs = {A, B};
+  auto pExecutor =
+      tc::aten::compile<tc::CudaBackend>(TC, "matmul", inputs, options);
+  auto outputs = tc::aten::prepareOutputs(TC, "matmul", inputs);
+  tc::aten::run(*pExecutor, inputs, outputs);
+}
+
 std::string makeUniqueName(const std::string& name) {
   static int count = 0;
   return name + std::string("_cnt") + std::to_string(++count);
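
As a worked example (not part of the patch series), the options in the test
above use .mapToThreads(4, 128), i.e. thread block sizes sx = 4 and sy = 128.
The block-to-warp mapping of PATCH 3 then linearizes a thread (x, y) as
x + 4 * y and divides by the warp size:

    // warp = floor((x + 4 * y) / 32)
    unsigned w0 = (0 + 4 * 0) / 32; // thread (0, 0) -> warp 0
    unsigned w1 = (3 + 4 * 7) / 32; // thread (3, 7) -> warp 0 (31 / 32)
    unsigned w2 = (0 + 4 * 8) / 32; // thread (0, 8) -> warp 1

so threads with y in [0, 8) share a warp, threads with y in [8, 16) share the
next one, and so on; dependences that stay within such a group only require
warp-level synchronization.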