diff --git a/tc/core/polyhedral/cuda/mapped_scop.cc b/tc/core/polyhedral/cuda/mapped_scop.cc index 955190ade..11211c0e6 100644 --- a/tc/core/polyhedral/cuda/mapped_scop.cc +++ b/tc/core/polyhedral/cuda/mapped_scop.cc @@ -431,167 +431,107 @@ bool hasOuterSequentialMember( return false; } -// Intersect the union set with all the mapping -// filters params in the given schedule tree -isl::union_set intersectMappingFilterParams( - detail::ScheduleTree* st, - isl::union_set us) { - if (auto filter = st->elemAsBase()) { - us = us.intersect(filter->filter_); - } +// Name of the space of threads inside a block +constexpr auto kBlock = "block"; +// Name of the space of warps +constexpr auto kWarp = "warp"; - auto children = st->children(); - auto nChildren = children.size(); - if (nChildren == 1) { - us = intersectMappingFilterParams(children[0], us); - } else if (nChildren > 1) { - auto usParent = us; - us = intersectMappingFilterParams(children[0], us); - for (size_t i = 1; i < nChildren; ++i) { - us = us.unite(intersectMappingFilterParams(children[i], usParent)); - } - } - - return us; -} +/* + * Extract a mapping from the domain elements active at "tree" + * to the thread identifiers, where all branches in "tree" + * are assumed to have been mapped to thread identifiers. + * "nThread" is the number of thread identifiers. + * The result lives in a space of the form block[x, ...]. 
+ */ +isl::multi_union_pw_aff extractDomainToThread( + const detail::ScheduleTree* tree, + size_t nThread) { + using namespace polyhedral::detail; -// Change the name of the isl ids tied to threads and blocks -// by adding a suffix -isl::union_set modifyMappingNames( - isl::union_set set, - const std::string suffix) { - USING_MAPPING_SHORT_NAMES(BX, BY, BZ, TX, TY, TZ); - std::unordered_set identifiers{ - BX, BY, BZ, TX, TY, TZ}; - - auto space = set.get_space(); - for (auto id : identifiers) { - auto name = id.get_name(); - auto dim = space.find_dim_by_name(isl::dim_type::param, id.get_name()); - CHECK_LE(0, dim); - space = space.set_dim_name(isl::dim_type::param, dim, name + suffix); - } - auto newSet = isl::union_set::empty(space); - set.foreach_set([&newSet, &identifiers, &suffix](isl::set setInFun) { - for (auto id : identifiers) { - auto name = id.get_name(); - auto dim = - setInFun.get_space().find_dim_by_name(isl::dim_type::param, name); - CHECK_LE(0, dim); - setInFun = - setInFun.set_dim_name(isl::dim_type::param, dim, name + suffix); + auto space = isl::space(tree->ctx_, 0); + auto empty = isl::union_set::empty(space); + auto id = isl::id(tree->ctx_, kBlock); + space = space.named_set_from_params_id(id, nThread); + auto zero = isl::multi_val::zero(space); + auto domainToThread = isl::multi_union_pw_aff(empty, zero); + + for (auto mapping : tree->collect(tree, ScheduleTreeType::MappingFilter)) { + auto mappingNode = mapping->elemAs(); + auto list = isl::union_pw_aff_list(tree->ctx_, nThread); + for (size_t i = 0; i < nThread; ++i) { + auto threadId = mapping::ThreadId::makeId(i); + auto threadMap = mappingNode->mapping.at(threadId); + list = list.add(threadMap); } - newSet = newSet.unite(setInFun); - }); - return newSet; -} - -// Get the formula computing the linearized index of a thread in a block. 
-isl::aff getLinearizedThreadIdxFormula( - isl::space space, - const Block& block, - const std::string& suffix = "") { - USING_MAPPING_SHORT_NAMES(BX, BY, BZ, TX, TY, TZ); - std::vector> mappingIds{ - {TX, TX.mappingSize(block)}, - {TY, TY.mappingSize(block)}, - {TZ, TZ.mappingSize(block)}}; - - isl::aff formula = isl::aff(isl::local_space(space)); - - for (int i = (int)mappingIds.size() - 1; i >= 0; --i) { - auto name = mappingIds[i].first.to_str(); - auto dim = space.find_dim_by_name(isl::dim_type::param, name + suffix); - CHECK_LE(0, dim); - auto id = space.get_dim_id(isl::dim_type::param, dim); - isl::aff aff(isl::aff::param_on_domain_space(space, id)); - formula = formula * mappingIds[i].second + aff; + auto nodeToThread = isl::multi_union_pw_aff(space, list); + domainToThread = domainToThread.union_add(nodeToThread); } - return formula; + return domainToThread; } -// Return the constraints ensuring that the points with parameters -// [t0,t1,t2] and [t0',t1',t2'] are in the same warp. -// (where t0 is "t0" + suffix1 and t0' is "t0" + suffix2) -// if suffix1 is "_1" and suffix2 is "_2", the constraint is in the form -// ((t0_1 + a * t1_1 + b * t2_1) / warpSize).floor() -// == ((t0_2 + a' * t1_2 + b' * t1_2) / warpSize).floor() -// with t0_1 + a * t1_1 + b * t2_1 the linearized formula of the thread index. -// This function returns a set because it might change in the future, -// and take into account the blocks. -isl::set getSameWarpConstraints( - isl::space space, - const std::string& suffix1, - const std::string& suffix2, +/* + * Construct a mapping + * + * block[x] -> warp[floor((x)/warpSize)] + * block[x, y] -> warp[floor((x + s_x * (y))/warpSize)] + * block[x, y, z] -> warp[floor((x + s_x * (y + s_y * (z)))/warpSize)] + * + * uniquely mapping thread identifiers that belong to the same warp + * (of size "warpSize") to a warp identifier, + * based on the thread sizes s_x, s_y up to s_z in "block". 
+ */ +isl::multi_aff constructThreadToWarp( + isl::ctx ctx, const unsigned warpSize, const Block& block) { - USING_MAPPING_SHORT_NAMES(BX, BY, BZ, TX, TY, TZ); - std::vector> mappingIds{ - {TX, TX.mappingSize(block)}, - {TY, TY.mappingSize(block)}, - {TZ, TZ.mappingSize(block)}}; - - auto formula1 = getLinearizedThreadIdxFormula(space, block, suffix1); - auto formula2 = getLinearizedThreadIdxFormula(space, block, suffix2); + auto nThread = block.view.size(); + auto space = isl::space(ctx, 0); + auto id = isl::id(ctx, kBlock); + auto blockSpace = space.named_set_from_params_id(id, nThread); + auto warpSpace = space.named_set_from_params_id(isl::id(ctx, kWarp), 1); + auto aff = isl::aff::zero_on_domain(blockSpace); + + auto identity = isl::multi_aff::identity(blockSpace.map_from_set()); + for (size_t i = nThread; i-- > 0;) { + aff = aff.scale(isl::val(ctx, block.view[i])); + aff = aff.add(identity.get_aff(i)); + } - return ( - isl::aff_set((formula1 / warpSize).floor()) == - (formula2 / warpSize).floor()); + aff = aff.scale_down(isl::val(ctx, warpSize)).floor(); + auto mapSpace = blockSpace.product(warpSpace).unwrap(); + return isl::multi_aff(mapSpace, isl::aff_list(aff)); } } // namespace Scop::SyncLevel MappedScop::findBestSync( detail::ScheduleTree* st1, - detail::ScheduleTree* st2) { + detail::ScheduleTree* st2, + isl::multi_union_pw_aff domainToThread, + isl::multi_union_pw_aff domainToWarp) { // Active points in the two schedule trees auto stRoot = scop_->scheduleRoot(); auto activePoints1 = activeDomainPointsBelow(stRoot, st1); auto activePoints2 = activeDomainPointsBelow(stRoot, st2); // The dependences between the two schedule trees - auto dependences = - isl::union_map::from_domain_and_range(activePoints1, activePoints2); - dependences = dependences.intersect(scop_->dependences); + auto dependences = scop_->dependences; + dependences = dependences.intersect_domain(activePoints1); + dependences = dependences.intersect_range(activePoints2); if
(dependences.is_empty()) { return Scop::SyncLevel::None; } - // The domain and the context of the root schedule tree - auto domainAndContext = scop_->domain(); CHECK_LE(1u, scop_->scheduleRoot()->children().size()); auto contextSt = scop_->scheduleRoot()->children()[0]; auto contextElem = contextSt->elemAs(); CHECK(nullptr != contextElem); - domainAndContext = domainAndContext.intersect_params(contextElem->context_); + dependences = dependences.intersect_params(contextElem->context_); - // The domain of both schedule trees filtered by mapping filters, - // and then modified to have different threads and blocks names. - auto domain1 = intersectMappingFilterParams(st1, domainAndContext); - auto domain2 = intersectMappingFilterParams(st2, domainAndContext); - auto suffix1 = "_1"; - auto suffix2 = "_2"; - domain1 = modifyMappingNames(domain1, suffix1); - domain2 = modifyMappingNames(domain2, suffix2); - - // The dependences between the two schedule trees - // with mapping from threads and blocks - auto mappedDependences = - isl::union_map::from_domain_and_range(domain1, domain2); - mappedDependences = mappedDependences.intersect(dependences); - - auto space = mappedDependences.get_space(); - auto sameThreadConstraint = - getSameWarpConstraints(space, suffix1, suffix2, 1, numThreads); - auto sameWarpConstraints = - getSameWarpConstraints(space, suffix1, suffix2, 32, numThreads); - - if (mappedDependences == - mappedDependences.intersect_params(sameThreadConstraint)) { + if (dependences.is_subset(dependences.eq_at(domainToThread))) { return Scop::SyncLevel::None; - } else if ( - mappedDependences == - mappedDependences.intersect_params(sameWarpConstraints)) { + } + if (dependences.is_subset(dependences.eq_at(domainToWarp))) { return Scop::SyncLevel::Warp; } return Scop::SyncLevel::Block; @@ -754,6 +694,10 @@ void MappedScop::insertBestSyncInSeq(detail::ScheduleTree* seq) { auto outer = hasOuterSequentialMember(scop_->scheduleRoot(), seq); + auto domainToThread = 
extractDomainToThread(seq, numThreads.view.size()); + auto threadToWarp = constructThreadToWarp(seq->ctx_, 32, numThreads); + auto domainToWarp = domainToThread.apply(threadToWarp); + std::vector> bestSync( nChildren, std::vector(nChildren + 1)); // Get the synchronization needed between children[i] and children[i+k] @@ -765,7 +709,8 @@ void MappedScop::insertBestSyncInSeq(detail::ScheduleTree* seq) { for (size_t i = 0; i < nChildren; ++i) { for (size_t k = 0; k < nChildren; ++k) { auto ik = (i + k) % nChildren; - bestSync[i][k] = (int)findBestSync(children[i], children[ik]); + bestSync[i][k] = (int)findBestSync( + children[i], children[ik], domainToThread, domainToWarp); } } diff --git a/tc/core/polyhedral/cuda/mapped_scop.h b/tc/core/polyhedral/cuda/mapped_scop.h index 6b84b704a..169b4f138 100644 --- a/tc/core/polyhedral/cuda/mapped_scop.h +++ b/tc/core/polyhedral/cuda/mapped_scop.h @@ -163,9 +163,13 @@ class MappedScop { // st1. // This function assumes that it is called before block mapping // and that st1 and st2 are already mapped to threads. + // "domainToThread" and "domainToWarp" map the domain elements + // of st1 and st2 to thread and warp identifiers, respectively. 
Scop::SyncLevel findBestSync( detail::ScheduleTree* st1, - detail::ScheduleTree* st2); + detail::ScheduleTree* st2, + isl::multi_union_pw_aff domainToThread, + isl::multi_union_pw_aff domainToWarp); public: // Find best configuration of synchronizations in a sequence, minimizing diff --git a/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc b/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc index c51e09a96..1c174e0ae 100644 --- a/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc +++ b/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc @@ -46,8 +46,8 @@ bool isThreadMapping(const detail::ScheduleTree* tree) { using namespace detail; if (auto filterNode = tree->elemAs()) { - for (auto id : filterNode->mappingIds) { - if (isThreadId(id)) { + for (auto& kvp : filterNode->mapping) { + if (isThreadId(kvp.first)) { return true; } } diff --git a/tc/core/polyhedral/schedule_print.cc b/tc/core/polyhedral/schedule_print.cc index b1371649f..d12b43541 100644 --- a/tc/core/polyhedral/schedule_print.cc +++ b/tc/core/polyhedral/schedule_print.cc @@ -202,8 +202,8 @@ std::ostream& ScheduleTreeElemFilter::write(std::ostream& os) const { std::ostream& ScheduleTreeElemMappingFilter::write(std::ostream& os) const { WS w; os << w.tab() << "mapping_filter(ids("; - for (const auto& id : mappingIds) { - os << id << ", "; + for (auto& kvp : mapping) { + os << kvp.first << ", "; } os << ")"; for (const auto& u : filter_.get_set_list()) { diff --git a/tc/core/polyhedral/schedule_tree-inl.h b/tc/core/polyhedral/schedule_tree-inl.h index b97a2ff1d..f3e4e614d 100644 --- a/tc/core/polyhedral/schedule_tree-inl.h +++ b/tc/core/polyhedral/schedule_tree-inl.h @@ -23,15 +23,19 @@ inline ScheduleTreeUPtr ScheduleTree::makeMappingFilter( const std::vector& mappedIds, isl::union_pw_aff_list mappedAffs, std::vector&& children) { - std::vector ids; - for (auto id : mappedIds) { - ids.push_back(id); + CHECK_EQ(mappedIds.size(), static_cast(mappedAffs.n())) + << "expected as many mapped ids as 
affs"; + ScheduleTreeElemMappingFilter::Mapping mapping; + for (size_t i = 0, n = mappedAffs.n(); i < n; ++i) { + mapping.emplace(mappedIds.at(i), mappedAffs.get(i)); } - CHECK_GE(ids.size(), 1u) << "empty mapping"; + CHECK_GE(mapping.size(), 1u) << "empty mapping"; + CHECK_EQ(mappedIds.size(), mapping.size()) + << "some id is used more than once in the mapping"; auto ctx = mappedIds[0].get_ctx(); ScheduleTreeUPtr res(new ScheduleTree(ctx)); res->elem_ = std::unique_ptr( - new ScheduleTreeElemMappingFilter(ids, mappedAffs)); + new ScheduleTreeElemMappingFilter(mapping)); res->type_ = ScheduleTreeType::MappingFilter; res->appendChildren(std::move(children)); return res; diff --git a/tc/core/polyhedral/schedule_tree_elem.h b/tc/core/polyhedral/schedule_tree_elem.h index 2f0e2921f..f1ab3153b 100644 --- a/tc/core/polyhedral/schedule_tree_elem.h +++ b/tc/core/polyhedral/schedule_tree_elem.h @@ -17,7 +17,7 @@ #include #include -#include +#include #include #include "tc/external/isl.h" @@ -139,34 +139,29 @@ struct ScheduleTreeElemFilter : public ScheduleTreeElemBase { }; struct ScheduleTreeElemMappingFilter : public ScheduleTreeElemFilter { + using Mapping = std::unordered_map< + mapping::MappingId, + isl::union_pw_aff, + typename mapping::MappingId::Hash>; static constexpr std::initializer_list NodeDerivedTypes{detail::ScheduleTreeType::None}; static constexpr detail::ScheduleTreeType NodeType = detail::ScheduleTreeType::MappingFilter; ScheduleTreeElemMappingFilter() = delete; ScheduleTreeElemMappingFilter(const ScheduleTreeElemMappingFilter& eb) - : ScheduleTreeElemFilter(eb.filter_), mappingIds(eb.mappingIds) {} - ScheduleTreeElemMappingFilter( - const std::vector& mappedIds, - isl::union_pw_aff_list mappedAffs) - : ScheduleTreeElemFilter(isl::union_set()), - mappingIds(mappedIds.begin(), mappedIds.end()) { - // Check that ids are unique. 
- CHECK_EQ(mappedIds.size(), mappingIds.size()) - << "some id is used more than once in the mapping"; + : ScheduleTreeElemFilter(eb.filter_), mapping(eb.mapping) {} + explicit ScheduleTreeElemMappingFilter(const Mapping& mapping) + : ScheduleTreeElemFilter(isl::union_set()), mapping(mapping) { + CHECK_GT(mapping.size(), 0u) << "empty mapping filter"; - CHECK_EQ(mappedIds.size(), static_cast(mappedAffs.n())) - << "expected as many mapped ids as affs"; - CHECK_GT(mappedIds.size(), 0u) << "empty mapping filter"; - - auto domain = mappedAffs.get(0).domain(); - for (size_t i = 1, n = mappedAffs.n(); i < n; ++i) { - CHECK(domain.is_equal(mappedAffs.get(i).domain())); + auto domain = mapping.cbegin()->second.domain(); + for (auto& kvp : mapping) { + CHECK(domain.is_equal(kvp.second.domain())); } filter_ = domain.universe(); - for (size_t i = 0, n = mappedAffs.n(); i < n; ++i) { - auto upa = mappedAffs.get(i); - auto id = mappedIds.at(i); + for (auto& kvp : mapping) { + auto upa = kvp.second; + auto id = kvp.first; // Create mapping filter by equating the - // parameter mappedIds[i] to the "i"-th affine function. + // parameter "id" to the affine function "upa" it is mapped to. upa = upa.sub(isl::union_pw_aff::param_on_domain(domain.universe(), id)); @@ -183,9 +178,8 @@ struct ScheduleTreeElemMappingFilter : public ScheduleTreeElemFilter { return NodeType; } - const std:: - unordered_set - mappingIds; + // Mapping from identifiers to affine functions on domain elements. + const Mapping mapping; }; struct ScheduleTreeElemSequence : public ScheduleTreeElemBase { diff --git a/test/cuda/test_tc_mapper_bugs.cc b/test/cuda/test_tc_mapper_bugs.cc index ace9a4de2..cfa812400 100644 --- a/test/cuda/test_tc_mapper_bugs.cc +++ b/test/cuda/test_tc_mapper_bugs.cc @@ -24,6 +24,7 @@ #include "tc/core/cuda/cuda_tc_executor.h" #include "tc/core/flags.h" #include "tc/core/polyhedral/exceptions.h" +#include "tc/library/matmul.h" #include "test_harness_aten_cuda.h" @@ -39,6 +40,37 @@ using namespace tc; // "Bug" suffix and fly away in the sun.
/////////////////////////////////////////////////////////////////////////////// +TEST(Matmul, SyncInsertionThreads4x128) { + auto TC = makeMatmulTc(); + auto options = + tc::CudaMappingOptions::makeNaiveMappingOptions() + .outerScheduleFusionStrategy(tc::FusionStrategy::Max) + .outerScheduleAllowSkewing(false) + .outerSchedulePositiveOrthant(true) + .intraTileScheduleFusionStrategy(tc::FusionStrategy::Min) + .intraTileScheduleAllowSkewing(false) + .intraTileSchedulePositiveOrthant(true) + .fixParametersBeforeScheduling(false) + .tile(56, 32, 4, 14, 16) + .unroll(16) + .tileImperfectlyNested(false) + .matchLibraryCalls(false) + .mapToThreads(4, 128) + .mapToBlocks(1, 32, 32) + .useSharedMemory(false) + .usePrivateMemory(true) + .unrollCopyShared(false) + .useReadOnlyCache(false); + uint32_t N = 100, K = 400, M = 500; + at::Tensor A = at::CUDA(at::kFloat).rand({N, K}); + at::Tensor B = at::CUDA(at::kFloat).rand({K, M}); + std::vector inputs = {A, B}; + auto pExecutor = + tc::aten::compile(TC, "matmul", inputs, options); + auto outputs = tc::aten::prepareOutputs(TC, "matmul", inputs); + tc::aten::run(*pExecutor, inputs, outputs); +} + std::string makeUniqueName(const std::string& name) { static int count = 0; return name + std::string("_cnt") + std::to_string(++count);