From d1017b828b9ab6e4c2b47e68e9fc03bf59694848 Mon Sep 17 00:00:00 2001 From: math-fehr Date: Thu, 19 Apr 2018 14:23:06 +0200 Subject: [PATCH 01/16] Added functions to use grid synchronization. As for now, there is a flag to disable grid synchronization. When the grid synchronization flag is enabled, tc will try to launch the kernel with a cooperative launch. The cooperative launch requires co-residency in the device, which is not currently checked. It will thus fail when there is too much blocks or threads. --- tc/core/CMakeLists.txt | 3 +- tc/core/constants.h | 1 + tc/core/cuda/cuda_libraries.h | 6 +++ tc/core/cuda/cuda_rtc.cc | 66 +++++++++++++++++++------ tc/core/flags.cc | 1 + tc/core/flags.h | 1 + tc/core/polyhedral/cuda/codegen.cc | 2 + tc/core/polyhedral/cuda/mapped_scop.cc | 2 +- tc/core/polyhedral/scop.h | 19 ++++++- tc/version/CMakeLists.txt | 13 ++++- tc/version/cuda_version.cc.in | 5 ++ tc/version/cuda_version.h | 19 +++++++ test/cuda/CMakeLists.txt | 2 +- test/cuda/test_basic_gpu.cc | 68 ++++++++++++++++++++++++-- 14 files changed, 185 insertions(+), 23 deletions(-) create mode 100644 tc/version/cuda_version.cc.in create mode 100644 tc/version/cuda_version.h diff --git a/tc/core/CMakeLists.txt b/tc/core/CMakeLists.txt index 4435f1533..ff64afd31 100644 --- a/tc/core/CMakeLists.txt +++ b/tc/core/CMakeLists.txt @@ -45,7 +45,7 @@ target_link_libraries( -lLLVM tc_lang - tc_version + tc_cuda_version tc_proto ) if (WITH_BINDINGS) @@ -176,6 +176,7 @@ if (WITH_CUDA) tc_lang tc_version + tc_cuda_version tc_proto tc_core ) diff --git a/tc/core/constants.h b/tc/core/constants.h index 8ec277583..92c779209 100644 --- a/tc/core/constants.h +++ b/tc/core/constants.h @@ -29,6 +29,7 @@ constexpr auto kReadIdName = "read"; constexpr auto kWriteIdName = "write"; constexpr auto kSyncIdPrefix = "_sync_"; constexpr auto kWarpSyncIdPrefix = "_warpSync_"; +constexpr auto kGridSyncIdPrefix = "_gridSync_"; } // namespace polyhedral } // namespace tc diff --git a/tc/core/cuda/cuda_libraries.h b/tc/core/cuda/cuda_libraries.h index c2fdb1de8..2f745cef6 100644 --- a/tc/core/cuda/cuda_libraries.h +++ b/tc/core/cuda/cuda_libraries.h @@ -61,6 +61,12 @@ __device__ void __syncwarp(unsigned mask = 0xFFFFFFFF) {} #endif )C"; +constexpr auto gridSyncFunctions = R"C( +__device__ void __syncgrid() { + cudaCGSynchronize(cudaCGGetIntrinsicHandle(cudaCGScopeGrid),0); +} +)C"; + constexpr auto mathFunctionDecl = R"C( // BEGIN MATH FUNCTIONS FROM CUDA diff --git a/tc/core/cuda/cuda_rtc.cc b/tc/core/cuda/cuda_rtc.cc index 412e397dc..0b959a4a2 100644 --- a/tc/core/cuda/cuda_rtc.cc +++ b/tc/core/cuda/cuda_rtc.cc @@ -25,6 +25,7 @@ #include "tc/core/cuda/cuda_rtc.h" #include "tc/core/flags.h" #include "tc/core/scope_guard.h" +#include "tc/version/cuda_version.h" namespace tc { std::mutex nvrtc_mutex; @@ -88,6 +89,9 @@ std::unique_ptr CudaRTCFunction::Compile( "-DNVRTC_CUB=1", cudaHome.c_str(), cubHome.c_str()}; + if (FLAGS_grid_sync) { + nvrtcts.push_back("--relocatable-device-code=true"); + } if (FLAGS_debug_cuda) { nvrtcts.push_back(nvrtc_debug_opts[0]); nvrtcts.push_back(nvrtc_debug_opts[1]); @@ -143,8 +147,28 @@ Duration CudaRTCFunction::Launch( if (perGpuModule_.count(dev) == 0) { CUmodule module; CUfunction function; - TC_CUDA_DRIVERAPI_ENFORCE( - cuModuleLoadDataEx(&module, nvrtc_ptx.data(), 0, 0, 0)); + if (FLAGS_grid_sync) { + CUlinkState linkState; + TC_CUDA_DRIVERAPI_ENFORCE(cuLinkCreate(0, 0, 0, &linkState)); + TC_CUDA_DRIVERAPI_ENFORCE(cuLinkAddFile( + linkState, CU_JIT_INPUT_LIBRARY, cuda_libdevrt_path, 0, 0, 0)); + TC_CUDA_DRIVERAPI_ENFORCE(cuLinkAddData( + linkState, + CU_JIT_INPUT_PTX, + (void*)nvrtc_ptx.data(), + nvrtc_ptx.size(), + "device_code.ptx", + 0, + 0, + 0)); + size_t cubinSize; + void* cubin; + TC_CUDA_DRIVERAPI_ENFORCE(cuLinkComplete(linkState, &cubin, &cubinSize)); + TC_CUDA_DRIVERAPI_ENFORCE(cuModuleLoadData(&module, cubin)); + } else { + TC_CUDA_DRIVERAPI_ENFORCE( + cuModuleLoadDataEx(&module, nvrtc_ptx.data(), 0, 0, 0)); + } perGpuModule_.emplace(dev, module); TC_CUDA_DRIVERAPI_ENFORCE( cuModuleGetFunction(&function, module, specializedName.c_str())); @@ -174,18 +198,32 @@ Duration CudaRTCFunction::Launch( unsigned int by = block[1]; unsigned int bz = block[2]; auto launch = [&]() { - TC_CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel( - perGpuKernel_.at(dev), - gx, - gy, - gz, - bx, - by, - bz, - shared_mem, - stream, - args_voidp.data(), - 0)); + if (FLAGS_grid_sync) { + TC_CUDA_DRIVERAPI_ENFORCE(cuLaunchCooperativeKernel( + perGpuKernel_.at(dev), + gx, + gy, + gz, + bx, + by, + bz, + shared_mem, + stream, + args_voidp.data())); + } else { + TC_CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel( + perGpuKernel_.at(dev), + gx, + gy, + gz, + bx, + by, + bz, + shared_mem, + stream, + args_voidp.data(), + 0)); + } }; if (not profile) { diff --git a/tc/core/flags.cc b/tc/core/flags.cc index 80c9e5ec5..d1f342426 100644 --- a/tc/core/flags.cc +++ b/tc/core/flags.cc @@ -37,6 +37,7 @@ DEFINE_bool( "Print debug spew for the tc_mapper like cuda code, mapping options etc"); DEFINE_bool(dump_cuda, false, "Print the generated source"); DEFINE_bool(dump_ptx, false, "Dump the generated PTX"); +DEFINE_bool(grid_sync, true, "Use the grid sync feature."); // CPU codegen options DEFINE_bool(llvm_dump_before_opt, false, "Print IR before optimization"); diff --git a/tc/core/flags.h b/tc/core/flags.h index c748759a4..86abcc1ab 100644 --- a/tc/core/flags.h +++ b/tc/core/flags.h @@ -30,6 +30,7 @@ DECLARE_bool(debug_cuda); DECLARE_bool(debug_tuner); DECLARE_bool(dump_cuda); DECLARE_bool(dump_ptx); +DECLARE_bool(grid_sync); // llvm codegen DECLARE_bool(llvm_dump_before_opt); diff --git a/tc/core/polyhedral/cuda/codegen.cc b/tc/core/polyhedral/cuda/codegen.cc index 032d5dcab..60eccbc06 100644 --- a/tc/core/polyhedral/cuda/codegen.cc +++ b/tc/core/polyhedral/cuda/codegen.cc @@ -450,6 +450,8 @@ void AstPrinter::emitStmt(isl::ast_node_user node) { context_.ss << "__syncthreads();" << std::endl; } else if (context_.scop().isWarpSyncId(stmtId)) { context_.ss << "__syncwarp();" << std::endl; + } else if (context_.scop().isGridSyncId(stmtId)) { + context_.ss << "__syncgrid();" << std::endl; } else if ( stmtId.get_name() == kReadIdName || stmtId.get_name() == kWriteIdName) { emitCopyStmt(statementContext); diff --git a/tc/core/polyhedral/cuda/mapped_scop.cc b/tc/core/polyhedral/cuda/mapped_scop.cc index 6be0812fe..58e045250 100644 --- a/tc/core/polyhedral/cuda/mapped_scop.cc +++ b/tc/core/polyhedral/cuda/mapped_scop.cc @@ -906,7 +906,7 @@ std::tuple MappedScop::codegen( std::stringstream code; code << code::cpp::boundsAsTemplate << code::c::types << code::c::defines; - code << code::c::warpSyncFunctions; + code << code::c::warpSyncFunctions << code::c::gridSyncFunctions; code << std::endl; if (useReadOnlyCache) { code << code::cuda::ldg; diff --git a/tc/core/polyhedral/scop.h b/tc/core/polyhedral/scop.h index b305315d8..cd611bbe6 100644 --- a/tc/core/polyhedral/scop.h +++ b/tc/core/polyhedral/scop.h @@ -195,7 +195,7 @@ struct Scop { isl::id updateId); // The different level of synchronization. - enum class SyncLevel : int { None = 0, Warp = 1, Block = 2 }; + enum class SyncLevel : int { None = 0, Warp = 1, Block = 2, Grid = 3 }; // Given a sequence node in the schedule tree, insert // synchronization before the child at position "pos". @@ -228,6 +228,10 @@ struct Scop { static size_t count = 0; return count++; } + size_t gridSyncUID() const { + static size_t count = 0; + return count++; + } // Make the synchronization id corresponding to the synchronization level. // The level should not be None. @@ -239,6 +243,9 @@ struct Scop { case SyncLevel::Block: return makeSyncId(); break; + case SyncLevel::Grid: + return makeGridSyncId(); + break; default: TC_CHECK(level != SyncLevel::None); return isl::id(); @@ -256,6 +263,12 @@ struct Scop { ctx, std::string(kWarpSyncIdPrefix) + std::to_string(warpSyncUID())); } + isl::id makeGridSyncId() const { + auto ctx = domain().get_ctx(); + return isl::id( + ctx, std::string(kGridSyncIdPrefix) + std::to_string(gridSyncUID())); + } + // Check if the id has a name with the expected prefix, followed by a long // integer. static bool isIdWithExpectedPrefix( @@ -282,6 +295,10 @@ struct Scop { return isIdWithExpectedPrefix(id, kWarpSyncIdPrefix); } + static bool isGridSyncId(isl::id id) { + return isIdWithExpectedPrefix(id, kGridSyncIdPrefix); + } + static isl::id makeRefId(isl::ctx ctx) { static thread_local size_t count = 0; return isl::id(ctx, std::string("__tc_ref_") + std::to_string(count++)); diff --git a/tc/version/CMakeLists.txt b/tc/version/CMakeLists.txt index cbf3ccfe2..1ad1e61a3 100644 --- a/tc/version/CMakeLists.txt +++ b/tc/version/CMakeLists.txt @@ -2,8 +2,19 @@ get_git_head_revision(GIT_REFSPEC GIT_HASH) execute_process(COMMAND ${GIT_EXECUTABLE} describe --all --long HEAD OUTPUT_VARIABLE GIT_TAG) SET(GIT_DESCRIPTION "${GIT_REFSPEC}${GIT_HASH}") configure_file(version.cc.in version.cc) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) add_library(tc_version STATIC ${CMAKE_CURRENT_BINARY_DIR}/version.cc) SET_TARGET_PROPERTIES(tc_version PROPERTIES LINKER_LANGUAGE CXX) -include_directories(${CMAKE_CURRENT_SOURCE_DIR}) target_include_directories(tc_version INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) + +if (WITH_CUDA) + find_library(CUDA_LIBDEVRT_PATH libcudadevrt.a + HINTS ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES lib lib64) + SET(CUDA_LIBDEVRT_PATH ${CUDA_LIBDEVRT_PATH}) + configure_file(cuda_version.cc.in cuda_version.cc) + add_library(tc_cuda_version STATIC ${CMAKE_CURRENT_BINARY_DIR}/cuda_version.cc) + SET_TARGET_PROPERTIES(tc_cuda_version PROPERTIES LINKER_LANGUAGE CXX) + target_include_directories(tc_cuda_version INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) +endif() diff --git a/tc/version/cuda_version.cc.in b/tc/version/cuda_version.cc.in new file mode 100644 index 000000000..c1915efb3 --- /dev/null +++ b/tc/version/cuda_version.cc.in @@ -0,0 +1,5 @@ +#include "cuda_version.h" + +namespace tc { +const char* cuda_libdevrt_path = "@CUDA_LIBDEVRT_PATH@"; +} diff --git a/tc/version/cuda_version.h b/tc/version/cuda_version.h new file mode 100644 index 000000000..00da80821 --- /dev/null +++ b/tc/version/cuda_version.h @@ -0,0 +1,19 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +namespace tc { +extern const char* cuda_libdevrt_path; +} diff --git a/test/cuda/CMakeLists.txt b/test/cuda/CMakeLists.txt index 564464388..72f106ffd 100644 --- a/test/cuda/CMakeLists.txt +++ b/test/cuda/CMakeLists.txt @@ -12,7 +12,7 @@ set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) ################################################################################ add_executable(test_basic_gpu test_basic_gpu.cc) add_test(test_basic_gpu test_basic_gpu) -target_link_libraries(test_basic_gpu ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARIES} ${CUDA_NVRTC_LIBRARIES} ${GOOGLE_LIBRARIES} pthread ) +target_link_libraries(test_basic_gpu ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARIES} ${CUDA_NVRTC_LIBRARIES} ${GOOGLE_LIBRARIES} pthread tc_cuda_version) ################################################################################ # Core GPU library tests diff --git a/test/cuda/test_basic_gpu.cc b/test/cuda/test_basic_gpu.cc index 4b856820e..e6815b29e 100644 --- a/test/cuda/test_basic_gpu.cc +++ b/test/cuda/test_basic_gpu.cc @@ -27,10 +27,12 @@ #include #include "tc/core/cuda/cuda.h" +#include "tc/version/cuda_version.h" std::vector jitCompile( std::string cuda, - std::vector extraCompileOptions = std::vector{}) { + std::vector extraCompileOptions = std::vector{}, + bool useGridSync = false) { // Actually do the compiling. nvrtcProgram prog; TC_NVRTC_CHECK( @@ -58,6 +60,9 @@ std::vector jitCompile( "-DNVRTC_CUB=1", "-lineinfo", cudaHome.c_str()}; + if (useGridSync) { + nvrtcts.push_back("--relocatable-device-code=true"); + } for (auto o : extraCompileOptions) { nvrtcts.push_back(o); } @@ -83,7 +88,7 @@ std::vector jitCompile( return PTX; } -void loadUnload(const std::string& ptx) { +void loadUnload(const std::string& ptx, bool useGridSync = false) { CUdevice cuDevice; CUcontext context; TC_CUDA_DRIVERAPI_ENFORCE(cuInit(0)); @@ -92,8 +97,28 @@ void loadUnload(const std::string& ptx) { CUmodule m; CUfunction k; - TC_CUDA_DRIVERAPI_ENFORCE(cuModuleLoadDataEx(&m, ptx.c_str(), 0, 0, 0)); - TC_CUDA_DRIVERAPI_ENFORCE(cuModuleGetFunction(&k, m, "foo")); + if (useGridSync) { + CUlinkState linkState; + TC_CUDA_DRIVERAPI_ENFORCE(cuLinkCreate(0, 0, 0, &linkState)); + TC_CUDA_DRIVERAPI_ENFORCE(cuLinkAddFile( + linkState, CU_JIT_INPUT_LIBRARY, tc::cuda_libdevrt_path, 0, 0, 0)); + TC_CUDA_DRIVERAPI_ENFORCE(cuLinkAddData( + linkState, + CU_JIT_INPUT_PTX, + (void*)ptx.data(), + ptx.size(), + "device_code.ptx", + 0, + 0, + 0)); + size_t cubinSize; + void* cubin; + TC_CUDA_DRIVERAPI_ENFORCE(cuLinkComplete(linkState, &cubin, &cubinSize)); + TC_CUDA_DRIVERAPI_ENFORCE(cuModuleLoadData(&m, cubin)); + } else { + TC_CUDA_DRIVERAPI_ENFORCE(cuModuleLoadDataEx(&m, ptx.c_str(), 0, 0, 0)); + TC_CUDA_DRIVERAPI_ENFORCE(cuModuleGetFunction(&k, m, "foo")); + } TC_CUDA_DRIVERAPI_ENFORCE(cuModuleUnload(m)); } @@ -113,6 +138,21 @@ __global__ void foo(int N) EXPECT_NE(std::string::npos, ptx.find(expected)); } +TEST(BasicGpuTest, NvrtcGridSync) { + auto PTX = jitCompile(R"CUDA( +extern "C" { +__global__ void foo(int N) +{ + assert(N == 1); +} +})CUDA", {"-G"}, true); + + std::string ptx(PTX.data()); + loadUnload(ptx, true); + auto expected = R"E(.visible .entry foo()E"; + EXPECT_NE(std::string::npos, ptx.find(expected)); +} + TEST(BasicGpuTest, CubReduce) { std::string path(CUB_HOME); std::string include = std::string("-I ") + path; @@ -149,6 +189,26 @@ __global__ void bar(float* o, const float* i) { EXPECT_NE(std::string::npos, ptx.find(expected)); } +TEST(BasicGpuTest, GridSync) { + auto PTX = jitCompile( + R"CUDA( +__device__ void __syncgrid() { + cudaCGSynchronize(cudaCGGetIntrinsicHandle(cudaCGScopeGrid),0); +} + +extern "C" { +__global__ void foo() { + __syncgrid(); +} +})CUDA", + {"-G"}, + true); + std::string ptx(PTX.data()); + loadUnload(ptx, true); + auto expected = R"E(.visible .entry foo()E"; + EXPECT_NE(std::string::npos, ptx.find(expected)); +} + namespace { // Mark the function argument as __restrict__ depending on the flag. std::string makeFuncWithOptionalRestrict(bool useRestrict) { From 795b542b77e29384ba8c13479c3624cfe83a1fed Mon Sep 17 00:00:00 2001 From: math-fehr Date: Thu, 19 Apr 2018 17:24:09 +0200 Subject: [PATCH 02/16] Search for multiple outermost coincident bands. This is a work in progress, and it currently lacks synchronization between the different parts. Instead of looking for an unique outermost band, looking for multiple outermost coincident bands can result in more degrees of parallelism. For instance, in the Copy2 example in test_cuda_mapper, this change result in the use of blocks, whereas it used only one block before. These change are only legal if there is a way to synchronize the different parts. This is possible thanks to grid synchronization (available only in CUDA 9 however). The coincident outer bands are the bands that are permutable and that have a dimension that is coincident. The outermosts coincident outer bands are the coincident outer bands that do not have a ancestor which is a coincident outer band. Some outermost coincident outer bands are added in the most higher level of the tree possible, to have another property: All leafs has an outermost coincident outer bands (which is unique by definition). By adding these bands in the most higher level of the tree possible, we add the minimum amount of these bands. Adding these bands are not yet necessary, but will be when adding synchronizations. They will be useful to reduce the number of grid synchronizations made. --- tc/core/polyhedral/cuda/mapped_scop.cc | 94 ++++++++++------ tc/core/polyhedral/cuda/mapped_scop.h | 4 + .../cuda/memory_promotion_heuristic.cc | 15 ++- .../cuda/memory_promotion_heuristic.h | 7 +- tc/core/polyhedral/schedule_tree.cc | 3 + tc/core/polyhedral/scop.cc | 106 +++++++++++++++++- tc/core/polyhedral/scop.h | 7 ++ test/test_cuda_mapper.cc | 4 +- test/test_cuda_mapper_memory_promotion.cc | 3 +- 9 files changed, 198 insertions(+), 45 deletions(-) diff --git a/tc/core/polyhedral/cuda/mapped_scop.cc b/tc/core/polyhedral/cuda/mapped_scop.cc index 58e045250..0ded66556 100644 --- a/tc/core/polyhedral/cuda/mapped_scop.cc +++ b/tc/core/polyhedral/cuda/mapped_scop.cc @@ -978,17 +978,25 @@ std::unique_ptr MappedScop::makeWithOuterBlockInnerThreadStrategy( // 3. Tile TC_CHECK_LT(0u, generic.tiling.size()) << "Must pass tile vector with >= 1 tile sizes"; - auto outerBand = scop->tileOuterBand(generic.tiling); + std::vector tiledBands; + if (FLAGS_grid_sync) { + tiledBands = scop->tileOuterCoincidentBands(generic.tiling); + } else { + tiledBands = {scop->tileOuterBand(generic.tiling)}; + } // 4. Optionally reschedule if point loops need a different strategy than // tile loops - if (generic.outerScheduleOptions != generic.intraTileScheduleOptions) { - scop->reschedule(outerBand->child({0}), generic.intraTileScheduleOptions); - LOG_IF(INFO, FLAGS_debug_tc_mapper) - << "After intra-tile rescheduling:" << std::endl - << *mappedScop->schedule(); + for (auto outerBand : tiledBands) { + if (generic.outerScheduleOptions != generic.intraTileScheduleOptions) { + scop->reschedule(outerBand->child({0}), generic.intraTileScheduleOptions); + } } + LOG_IF(INFO, FLAGS_debug_tc_mapper) + << "After intra-tile rescheduling:" << std::endl + << *mappedScop->schedule(); + // 1b. ...or after rescheduling if (!generic.proto.fix_parameters_before_scheduling()) { scop->specializeToContext(); @@ -998,30 +1006,37 @@ std::unique_ptr MappedScop::makeWithOuterBlockInnerThreadStrategy( mappedScop->insertMappingContext(); // 6. Map to threads - if (outerBand->numChildren() > 0) { - TC_CHECK_EQ(1u, outerBand->numChildren()); - // 6.1. Optionally detect reductions while mapping to threads + for (auto outerBand : tiledBands) { + if (outerBand->numChildren() > 0) { + TC_CHECK_EQ(1u, outerBand->numChildren()); - if (generic.proto.match_library_calls()) { - mappedScop->detectReductions(outerBand->child({0})); + // 6.1. Optionally detect reductions while mapping to threads + if (generic.proto.match_library_calls()) { + mappedScop->detectReductions(outerBand->child({0})); + } + auto child = outerBand->child({0}); + size_t numMappedInnerThreads = + mappedScop->mapInnermostBandsToThreads(child); + fixThreadsBelow(*mappedScop, outerBand, numMappedInnerThreads); } - auto child = outerBand->child({0}); - size_t numMappedInnerThreads = - mappedScop->mapInnermostBandsToThreads(child); - fixThreadsBelow(*mappedScop, outerBand, numMappedInnerThreads); - LOG_IF(INFO, FLAGS_debug_tc_mapper) - << "After mapping to threads:" << std::endl - << *mappedScop->schedule(); } + LOG_IF(INFO, FLAGS_debug_tc_mapper) + << "After mapping to threads:" << std::endl + << *mappedScop->schedule(); + // 7. Map to blocks - mappedScop->mapToBlocksAndScaleBand( - outerBand, generic.tiling.extractVector()); + for (auto outerBand : tiledBands) { + mappedScop->mapToBlocksAndScaleBand( + outerBand, generic.tiling.extractVector()); + } + LOG_IF(INFO, FLAGS_debug_tc_mapper) << "After mapping to blocks:" << std::endl << *mappedScop->schedule(); // 8. Promote to shared memory below the loops mapped to blocks. - // This may split the outer band, so find the new outer band after promotion. + // This may split the outer band, so find the new outer band after + // promotion. if (cudaOptions.proto().use_shared_memory()) { size_t sharedMemorySize = cudaOptions.proto().has_max_shared_memory() ? cudaOptions.proto().max_shared_memory() @@ -1040,29 +1055,38 @@ std::unique_ptr MappedScop::makeWithOuterBlockInnerThreadStrategy( sharedMemorySize -= reductionMemoryRequirement; } - auto band = outerBand->elemAs(); - LOG_IF(WARNING, FLAGS_debug_tc_mapper && band->nMember() == 0) - << "Aborting memory promotion because outer band has 0 members (NYI)"; - if (band->nMember() > 0 && sharedMemorySize > 0) { + if (sharedMemorySize > 0) { LOG_IF( WARNING, cudaOptions.proto().unroll_copy_shared() && !generic.proto.has_unroll()) << "requested to unroll copies to shared memory without providing the unroll size"; + bool unroll = cudaOptions.proto().unroll_copy_shared() && + generic.proto.has_unroll(); + for (auto outerBand : tiledBands) { + auto band = outerBand->elemAs(); + + LOG_IF(WARNING, FLAGS_debug_tc_mapper && band->nMember() == 0) + << "Aborting memory promotion for one band because it has 0 members (NYI)"; + if (band->nMember() == 0) { + continue; + } - promoteGreedilyAtDepth( - *mappedScop, - std::min(band->nOuterCoincident(), mappedScop->numBlocks.view.size()), - sharedMemorySize, - cudaOptions.proto().unroll_copy_shared() && - generic.proto.has_unroll()); + sharedMemorySize = promoteGreedilyAtDepth( + *mappedScop, + outerBand, + std::min( + band->nOuterCoincident(), mappedScop->numBlocks.view.size()), + sharedMemorySize, + cudaOptions.proto().unroll_copy_shared() && + generic.proto.has_unroll()); - auto bands = ScheduleTree::collectDFSPreorder( + /*auto bands = ScheduleTree::collectDFSPreorder( scop->scheduleRoot(), ScheduleTreeType::Band); - if (bands.size() == 0) { // Sanity check. - throw NoBandsException("no bands after promotion"); + if (bands.size() == 0) { // Sanity check. + throw NoBandsException("no bands after promotion"); + }*/ } - outerBand = bands[0]; } } diff --git a/tc/core/polyhedral/cuda/mapped_scop.h b/tc/core/polyhedral/cuda/mapped_scop.h index 55490596a..5051a629c 100644 --- a/tc/core/polyhedral/cuda/mapped_scop.h +++ b/tc/core/polyhedral/cuda/mapped_scop.h @@ -215,6 +215,10 @@ class MappedScop { // Return a pointer to the split off tile. detail::ScheduleTree* splitOutReductionTileAndInsertSyncs( detail::ScheduleTree* band); + // Insert grid synchronization where needed. + // They are inserted in every sequence and below every sequential loop + // which are above the outer coincident bands. + void insertGridSyncs(const std::vector& outerBands); // Map "band" to thread identifiers using as many blockSizes values as outer // coincident dimensions (plus reduction dimension, if any), // insert synchronization in case of a reduction, and diff --git a/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc b/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc index 4804fdb04..d3b76e77b 100644 --- a/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc +++ b/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc @@ -487,9 +487,10 @@ std::vector bandsSplitAfterDepth( * Only promote if the tensor elements referenced by the group are reused or * accessed in a non-coalesced way. */ -void promoteToSharedGreedy( +size_t promoteToSharedGreedy( Scop& scop, const Block& block, + detail::ScheduleTree* tree, size_t depth, size_t maxMemory) { using namespace tc::polyhedral::detail; @@ -503,7 +504,7 @@ void promoteToSharedGreedy( // 1. Collect all bands with a member located at the given depth in the // overall schedule. Make sure this is the last member of the band by // splitting off the subsequent members into a different band. - auto bands = bandsContainingScheduleDepth(root, depth); + auto bands = bandsContainingScheduleDepth(tree, depth); bands = bandsSplitAfterDepth(bands, root, depth); // 2. Compute full schedule without mapping filters. The filters would make @@ -600,20 +601,24 @@ void promoteToSharedGreedy( } scop.insertSyncsAroundCopies(bandNode); } + return remainingMemory; } } // namespace -void promoteGreedilyAtDepth( +size_t promoteGreedilyAtDepth( MappedScop& mscop, + detail::ScheduleTree* st, size_t depth, size_t sharedMemorySize, bool unrollCopies) { // 1. Promote using heuristic. - promoteToSharedGreedy( - mscop.scop(), mscop.numThreads, depth, sharedMemorySize); + sharedMemorySize = promoteToSharedGreedy( + mscop.scop(), mscop.numThreads, st, depth, sharedMemorySize); // 2. Map copies to shared, state by copy mapCopiesToThreads(mscop, unrollCopies); + + return sharedMemorySize; } /* diff --git a/tc/core/polyhedral/cuda/memory_promotion_heuristic.h b/tc/core/polyhedral/cuda/memory_promotion_heuristic.h index 508a3d8f6..b4de0df19 100644 --- a/tc/core/polyhedral/cuda/memory_promotion_heuristic.h +++ b/tc/core/polyhedral/cuda/memory_promotion_heuristic.h @@ -21,6 +21,9 @@ #include "tc/external/isl.h" namespace tc { + +class Block; + namespace polyhedral { class MappedScop; class Scop; @@ -36,8 +39,10 @@ class ScheduleTree; // "threadIdxXScheduleDepthState" contains the schedule depth at which the // computation was mapped to thread x and is used to check whether the global // memory is accessed in a coalesced way. -void promoteGreedilyAtDepth( +// Return the remaining memory. +size_t promoteGreedilyAtDepth( MappedScop& scop, + detail::ScheduleTree* st, std::size_t depth, std::size_t sharedMemorySize, bool unrollCopies); diff --git a/tc/core/polyhedral/schedule_tree.cc b/tc/core/polyhedral/schedule_tree.cc index d087539d9..5610efd26 100644 --- a/tc/core/polyhedral/schedule_tree.cc +++ b/tc/core/polyhedral/schedule_tree.cc @@ -187,6 +187,9 @@ vector ScheduleTree::positionRelativeTo( } size_t ScheduleTree::scheduleDepth(const ScheduleTree* relativeRoot) const { + if (relativeRoot == this) { + return 0; + } size_t depth = 0; for (auto const& anc : ancestors(relativeRoot)) { auto bandElem = anc->elemAs(); diff --git a/tc/core/polyhedral/scop.cc b/tc/core/polyhedral/scop.cc index f7cc8bc76..8d22101b3 100644 --- a/tc/core/polyhedral/scop.cc +++ b/tc/core/polyhedral/scop.cc @@ -453,6 +453,92 @@ detail::ScheduleTree* Scop::obtainOuterBand() { return tree; } +namespace { + +// Check if there is at least one band with a coincident dimension +// in the tree. +bool hasAtLeastOneCoincidentLoop(detail::ScheduleTree* tree) { + if (auto band = tree->elemAs()) { + if (find(band->coincident_.begin(), band->coincident_.end(), true) != + band->coincident_.end() && + band->permutable_) { + return true; + } + } + auto n = tree->numChildren(); + if (n == 0) { + return false; + } else if (n == 1) { + return hasAtLeastOneCoincidentLoop(tree->child({0})); + } else { + for (size_t i = 0; i < n; ++i) { + if (hasAtLeastOneCoincidentLoop(tree->child({i}))) { + return true; + } + } + } + return false; +} + +// Return the outermost coincident band. These are the bands with at least +// one coincident dimension, that are permutable, and that have no such +// coincident band as ancestors. Some of these bands are created with zero +// dimensions to ensure that the union of all children of all outermost +// coincident bands is equal to the leafs. Also, these created bands are +// put at the highest possible level of the tree, which also minimize the +// number of bands inserted. + +std::vector obtainOuterCoincidentBands( + detail::ScheduleTree* root, + detail::ScheduleTree* tree) { + auto nChildren = tree->numChildren(); + + // If there is no coincident loop, we create an outer band with no dimensions + if (not hasAtLeastOneCoincidentLoop(tree)) { + auto domain = root->elemAs(); + CHECK(domain); + auto band = ScheduleTree::makeEmptyBand(root); + if (nChildren == 0 || root == tree) { + return {setPermutable(insertNodeBelow(tree, std::move(band)))}; + } else { + return {setPermutable(insertNodeAbove(root, tree, std::move(band)))}; + } + } + + // If we find a coincident loop directly, we only have one outer band in tree. + while (nChildren == 1) { + if (auto band = tree->elemAs()) { + if (find(band->coincident_.begin(), band->coincident_.end(), true) != + band->coincident_.end() && + band->permutable_) { + return {tree}; + } + } + tree = tree->child({0}); + nChildren = tree->numChildren(); + } + + // If nChidren is null, that means that the current tree is a band with + // a coincident member. Otherwise, hasAtLeastOneCoincidentLoop would a + // return false + if (nChildren == 0) { + return {tree}; + } + + auto children = tree->children(); + std::vector outerBands = {}; + // Get the outer bands from the children, and add it to the + // list of outer bands. + for (size_t i = 0; i < nChildren; ++i) { + auto childOuterBand = obtainOuterCoincidentBands(root, tree->child({i})); + outerBands.insert( + outerBands.end(), childOuterBand.begin(), childOuterBand.end()); + } + return outerBands; +} + +} // namespace + detail::ScheduleTree* Scop::tileOuterBand(const TilingView& tileSizes) { using namespace tc::polyhedral::detail; auto band = obtainOuterBand(); @@ -467,6 +553,24 @@ detail::ScheduleTree* Scop::tileOuterBand(const TilingView& tileSizes) { return res; } +std::vector Scop::tileOuterCoincidentBands( + const TilingView& tileSizes) { + using namespace tc::polyhedral::detail; + auto bands = obtainOuterCoincidentBands(scheduleRoot(), scheduleRoot()); + std::vector sizes = tileSizes.extractVector(); + std::vector tiledBands = {}; + for (auto band : bands) { + auto bandNode = band->elemAs(); + if (bandNode->nMember() < sizes.size()) { + sizes.resize(bandNode->nMember()); + } + tiledBands.push_back(bandTile(band, sizes, TileOptions::ShiftPointLoops)); + } + LOG_IF(INFO, FLAGS_debug_tc_mapper) << "After tiling outer:" << std::endl + << *scheduleTreeUPtr; + return tiledBands; +} + void Scop::reschedule( ScheduleTree* tree, const SchedulerOptionsView& schedulerOptions) { @@ -483,8 +587,6 @@ void Scop::reschedule( auto newTree = computeSchedule(constraints, schedulerOptions); parentTree->detachChild(treePos); parentTree->insertChildren(treePos, newTree->detachChildren()); - LOG_IF(INFO, FLAGS_debug_tc_mapper) << "After rescheduling:" << std::endl - << *scheduleTreeUPtr; } const Halide::OutputImageParam& Scop::findArgument(isl::id id) const { diff --git a/tc/core/polyhedral/scop.h b/tc/core/polyhedral/scop.h index cd611bbe6..fd6f10e81 100644 --- a/tc/core/polyhedral/scop.h +++ b/tc/core/polyhedral/scop.h @@ -421,6 +421,13 @@ struct Scop { // tile loop band. detail::ScheduleTree* tileOuterBand(const TilingView& tiling); + // Tile the outermost coincident bands. See the obtainOuterCoincidentBands + // function to see what these bands correspond to. + // Splits the bands into tile loop band and point loop band where point loops + // have fixed trop counts specified in "tiling", and returns a pointer to + // the tile loop bands. + std::vector tileOuterCoincidentBands( + const TilingView& tiling); // Reschedule the schedule subtree rooted at "tree" with the // given scheduler options. void reschedule( diff --git a/test/test_cuda_mapper.cc b/test/test_cuda_mapper.cc index 3f43178ad..8b95b58eb 100644 --- a/test/test_cuda_mapper.cc +++ b/test/test_cuda_mapper.cc @@ -671,8 +671,10 @@ def fun() -> (O) { * an outer sequence. * Check that no synchronizations are inserted, since there is no * dependences between threads. + * + * The test is disabled since it is now possible to map the operation to blocks. */ -TEST_F(PolyhedralMapperTest, Copy2) { +TEST_F(PolyhedralMapperTest, DISABLED_Copy2) { auto tc = R"TC( def fun(float(N) I) -> (O1, O2) { O1(n) = I(n) diff --git a/test/test_cuda_mapper_memory_promotion.cc b/test/test_cuda_mapper_memory_promotion.cc index ee55a10c5..d215d2501 100644 --- a/test/test_cuda_mapper_memory_promotion.cc +++ b/test/test_cuda_mapper_memory_promotion.cc @@ -394,7 +394,8 @@ def fun(float(N, M) A) -> (B, C) { size_t maxSharedMemory) { auto mscop = prepareScop( tc, {{"N", problemSize1}, {"M", problemSize2}}, {tileSize1, tileSize2}); - promoteGreedilyAtDepth(*mscop, depth, maxSharedMemory, false); + promoteGreedilyAtDepth( + *mscop, mscop->scop().scheduleRoot(), depth, maxSharedMemory, false); return mscop; } }; From 75c693c8dea121e3f572f96a0fcee2895c65b388 Mon Sep 17 00:00:00 2001 From: math-fehr Date: Fri, 20 Apr 2018 16:33:55 +0200 Subject: [PATCH 03/16] Added grid synchronizations where needed. The grid synchronizations are added above the outermost coincident bands, in every sequence, and below every sequential band. --- tc/core/polyhedral/cuda/mapped_scop.cc | 55 ++++++++++++++++++++++++-- third-party/islpp | 2 +- 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/tc/core/polyhedral/cuda/mapped_scop.cc b/tc/core/polyhedral/cuda/mapped_scop.cc index 0ded66556..5ffc59d9f 100644 --- a/tc/core/polyhedral/cuda/mapped_scop.cc +++ b/tc/core/polyhedral/cuda/mapped_scop.cc @@ -953,6 +953,49 @@ detail::ScheduleTree* MappedScop::splitOutReductionTileAndInsertSyncs( return child; } +namespace { +// Insert grid synchronizations where needed in st. +void insertGridSyncsDFS( + Scop& scop, + const std::vector& outerBands, + detail::ScheduleTree* st) { + using namespace polyhedral::detail; + // If the root of the schedule tree is an outermost coincident band, there is + // no synchronization left. + if (std::find(outerBands.begin(), outerBands.end(), st) != outerBands.end()) { + return; + } + auto nChildren = st->numChildren(); + auto children = st->children(); + for (size_t i = 0; i < nChildren; ++i) { + insertGridSyncsDFS(scop, outerBands, children[i]); + } + + // Insert synchronizations in sequences. + if (st->elemAs()) { + CHECK(nChildren); + if (hasOuterSequentialMember(scop.scheduleRoot(), st)) { + scop.insertSync(st, nChildren, Scop::SyncLevel::Grid); + } + for (size_t i = nChildren - 1; i > 0; --i) { + scop.insertSync(st, i, Scop::SyncLevel::Grid); + } + } + + // Insert synchronizations after sequential loops. + if (st->elemAs()) { + scop.insertSyncAfter(st, Scop::SyncLevel::Grid); + } +} +} // namespace + +void MappedScop::insertGridSyncs( + const std::vector& outerBands) { + using namespace polyhedral::detail; + + insertGridSyncsDFS(*scop_, outerBands, scop_->scheduleRoot()); +} + std::unique_ptr MappedScop::makeWithOuterBlockInnerThreadStrategy( std::unique_ptr&& scopUPtr, const CudaMappingOptions& cudaOptions) { @@ -975,7 +1018,7 @@ std::unique_ptr MappedScop::makeWithOuterBlockInnerThreadStrategy( // 2. Schedule scop = Scop::makeScheduled(*scop, generic.outerScheduleOptions); - // 3. Tile + // 3. Find and Tile outermost coincident bands... TC_CHECK_LT(0u, generic.tiling.size()) << "Must pass tile vector with >= 1 tile sizes"; std::vector tiledBands; @@ -1034,7 +1077,13 @@ std::unique_ptr MappedScop::makeWithOuterBlockInnerThreadStrategy( LOG_IF(INFO, FLAGS_debug_tc_mapper) << "After mapping to blocks:" << std::endl << *mappedScop->schedule(); - // 8. Promote to shared memory below the loops mapped to blocks. + // 8. Insert grid synchronization where needed. + mappedScop->insertGridSyncs(tiledBands); + LOG_IF(INFO, FLAGS_debug_tc_mapper) + << "After inserting grid synchronization:" << std::endl + << *mappedScop->schedule(); + + // 9. Promote to shared memory below the loops mapped to blocks. // This may split the outer band, so find the new outer band after // promotion. if (cudaOptions.proto().use_shared_memory()) { @@ -1090,7 +1139,7 @@ std::unique_ptr MappedScop::makeWithOuterBlockInnerThreadStrategy( } } - // 9. Promote to registers below the loops mapped to threads. + // 10. Promote to registers below the loops mapped to threads. if (cudaOptions.proto().use_private_memory()) { promoteToRegistersBelowThreads(*mappedScop, -1ull); } diff --git a/third-party/islpp b/third-party/islpp index 12a5b72fb..0cf707e50 160000 --- a/third-party/islpp +++ b/third-party/islpp @@ -1 +1 @@ -Subproject commit 12a5b72fb7c5fe4a841b691287ce4233a30f4b0f +Subproject commit 0cf707e5003570c9df45a0b5b776429b5f0ce298 From 1adbdd53a29ccf0a4184a1d44ff5457eda0be8d2 Mon Sep 17 00:00:00 2001 From: math-fehr Date: Mon, 23 Apr 2018 16:25:53 +0200 Subject: [PATCH 04/16] Added elements in mapping options to get resources per sm. Theses informations are needed to know if a cooperative kernel can be launched. Indeed, cooperative launch requires co-residency of all threads and blocks. --- tc/core/cuda/cuda.cc | 71 ++++++++++++++++++++++++++++++++-- tc/core/cuda/cuda.h | 21 +++++++++- tc/core/gpu.h | 47 ++++++++++++++++++++++ tc/proto/mapping_options.proto | 15 +++++++ 4 files changed, 148 insertions(+), 6 deletions(-) diff --git a/tc/core/cuda/cuda.cc b/tc/core/cuda/cuda.cc index 108e058cd..6fa492f74 100644 --- a/tc/core/cuda/cuda.cc +++ b/tc/core/cuda/cuda.cc @@ -30,7 +30,14 @@ DEFINE_bool(use_nvprof, false, "Start / stop nvprof"); namespace { -std::tuple, std::vector> init() { +std::tuple< + std::vector, + std::vector, + std::vector, + std::vector, + std::vector, + std::vector> +init() { int deviceCount = 0; auto err_id = cudaGetDeviceCount(&deviceCount); if (err_id == 35 or err_id == 30) { @@ -44,14 +51,36 @@ std::tuple, std::vector> init() { } std::vector gpuNames; std::vector sharedMemSizes; + std::vector sharedMemSizesPerSM; + std::vector blocksPerSM; + std::vector threadsPerSM; + std::vector nbOfSM; gpuNames.reserve(deviceCount); for (int i = 0; i < deviceCount; ++i) { cudaDeviceProp deviceProp; TC_CUDA_RUNTIMEAPI_ENFORCE(cudaGetDeviceProperties(&deviceProp, i)); gpuNames.emplace_back(deviceProp.name); sharedMemSizes.emplace_back(deviceProp.sharedMemPerBlock); + sharedMemSizesPerSM.emplace_back(deviceProp.sharedMemPerMultiprocessor); + + // There is currently no way to get the number of blocks per sm + // with the CUDA api. The only relevant solution is to compute it + // with the compute capability. + // the formula works if the number of blocks per sm is nondecreasing after + // the 6.0 compute capability. + auto major = deviceProp.major; + blocksPerSM.emplace_back(major < 3 ? 8 : (major < 4 ? 16 : 32)); + + threadsPerSM.emplace_back(deviceProp.maxThreadsPerMultiProcessor); + nbOfSM.emplace_back(deviceProp.multiProcessorCount); } - return std::make_tuple(gpuNames, sharedMemSizes); + return std::make_tuple( + gpuNames, + sharedMemSizes, + sharedMemSizesPerSM, + blocksPerSM, + threadsPerSM, + nbOfSM); } } // namespace @@ -61,8 +90,13 @@ CudaGPUInfo& CudaGPUInfo::GPUInfo() { static thread_local bool inited = false; if (!inited) { auto infos = init(); - pInfo = std::unique_ptr( - new CudaGPUInfo(std::get<0>(infos), std::get<1>(infos))); + pInfo = std::unique_ptr(new CudaGPUInfo( + std::get<0>(infos), + std::get<1>(infos), + std::get<2>(infos), + std::get<3>(infos), + std::get<4>(infos), + std::get<5>(infos))); inited = true; } return *pInfo; @@ -102,4 +136,33 @@ size_t CudaGPUInfo::SharedMemorySize() const { } return sharedMemSizes_.at(CurrentGPUId()); } + +size_t CudaGPUInfo::SharedMemorySizePerSM() const { + if (NumberGPUs() == 0) { + return 0; // no shared memory per sm if no GPUs + } + return sharedMemSizesPerSM_.at(CurrentGPUId()); +} + +size_t CudaGPUInfo::BlocksPerSM() const { + if (NumberGPUs() == 0) { + return 0; // no blocks per sm if no GPUs + } + return blocksPerSM_.at(CurrentGPUId()); +} + +size_t CudaGPUInfo::ThreadsPerSM() const { + if (NumberGPUs() == 0) { + return 0; // no threads per sm if no GPUs + } + return threadsPerSM_.at(CurrentGPUId()); +} + +size_t CudaGPUInfo::NbOfSM() const { + if (NumberGPUs() == 0) { + return 0; // no sm if no GPUs + } + return nbOfSM_.at(CurrentGPUId()); +} + } // namespace tc diff --git a/tc/core/cuda/cuda.h b/tc/core/cuda/cuda.h index 70406f6a4..30d9fbe65 100644 --- a/tc/core/cuda/cuda.h +++ b/tc/core/cuda/cuda.h @@ -96,8 +96,17 @@ struct WithCudaDevice { class CudaGPUInfo { CudaGPUInfo( const std::vector& gpuNames, - const std::vector& sharedMemSizes) - : gpuNames_(gpuNames), sharedMemSizes_(sharedMemSizes) {} + const std::vector& sharedMemSizes, + const std::vector& sharedMemSizesPerSM, + const std::vector& blocksPerSM, + const std::vector& threadsPerSM, + const std::vector& nbOfSM) + : gpuNames_(gpuNames), + sharedMemSizes_(sharedMemSizes), + sharedMemSizesPerSM_(sharedMemSizesPerSM), + blocksPerSM_(blocksPerSM), + threadsPerSM_(threadsPerSM), + nbOfSM_(nbOfSM) {} public: static CudaGPUInfo& GPUInfo(); @@ -110,9 +119,17 @@ class CudaGPUInfo { std::string GetGPUName(int id = -1) const; std::string getCudaDeviceStr() const; size_t SharedMemorySize() const; + size_t SharedMemorySizePerSM() const; + size_t BlocksPerSM() const; + size_t ThreadsPerSM() const; + size_t NbOfSM() const; std::vector gpuNames_; std::vector sharedMemSizes_; + std::vector sharedMemSizesPerSM_; + std::vector blocksPerSM_; + std::vector threadsPerSM_; + std::vector nbOfSM_; }; struct CudaProfiler { diff --git a/tc/core/gpu.h b/tc/core/gpu.h index cdacc05a3..1e9328f9f 100644 --- a/tc/core/gpu.h +++ b/tc/core/gpu.h @@ -34,4 +34,51 @@ inline size_t querySharedMemorySize() { #endif } +/// Get the shared memory size per sm of the GPU device active in the current +/// thread. +/// The call is forwarded to the appropriate GPU driver (CUDA in particular). +/// If a thread has no associated GPU device, return 0. +inline size_t querySharedMemorySizePerSM() { +#ifdef CUDA_HOME + return CudaGPUInfo::GPUInfo().SharedMemorySizePerSM(); +#else + return 0; +#endif +} + +/// Get the maximum number of blocks per sm of the GPU device active +/// in the current thread. +/// The call is forwarded to the appropriate GPU driver (CUDA in particular). +/// If a thread has no associated GPU device, return 0. +inline size_t queryBlocksPerSM() { +#ifdef CUDA_HOME + return CudaGPUInfo::GPUInfo().BlocksPerSM(); +#else + return 0; +#endif +} + +/// Get the maximum number of threads per sm of the GPU device active +/// in the current thread. +/// The call is forwarded to the appropriate GPU driver (CUDA in particular). +/// If a thread has no associated GPU device, return 0. +inline size_t queryThreadsPerSM() { +#ifdef CUDA_HOME + return CudaGPUInfo::GPUInfo().ThreadsPerSM(); +#else + return 0; +#endif +} + +/// Get the number of sm on the GPU device active in the current thread. +/// The call is forwarded to the appropriate GPU driver (CUDA in particular). +/// If a thread has no associated GPU device, return 0. +inline size_t queryNbOfSM() { +#ifdef CUDA_HOME + return CudaGPUInfo::GPUInfo().NbOfSM(); +#else + return 0; +#endif +} + } // namespace tc diff --git a/tc/proto/mapping_options.proto b/tc/proto/mapping_options.proto index ff29e3557..c413f1b8c 100644 --- a/tc/proto/mapping_options.proto +++ b/tc/proto/mapping_options.proto @@ -70,6 +70,21 @@ message CudaMappingOptionsProto { optional uint64 max_shared_memory = 7; // Use the readonly cache (i.e. emit __ldg loads) required bool use_readonly_cache = 8; + // Maximum number of blocks per streaming multiprocessor to use. If not + // provided, the number of blocks per sm on the current active device will + // be used. + optional uint64 max_blocks_per_sm = 9; + // Maximum number of threads per streaming multiprocessor to use. If not + // provided, the number of threads per sm on the current active device will + // be used. + optional uint64 max_threads_per_sm = 10; + // Maximum number of shared memory per streaming multiprocessor to use, in + // bytes. If not provided, all shared memory available on a sm on the current + // active device will be used. + optional uint64 max_shared_memory_per_sm = 11; + // The number of multiprocessors that can be used. If not provided, all + // multiprocessors will be used. + optional uint64 nb_of_sm = 12; } message CpuMappingOptionsProto { From 892bbf72f49a8919736395bcfc7e579cbf63bae8 Mon Sep 17 00:00:00 2001 From: math-fehr Date: Wed, 25 Apr 2018 14:24:34 +0200 Subject: [PATCH 05/16] Use a boolean to know if there is a grid synchronization for rtc. --- tc/core/cuda/cuda_backend.h | 1 + tc/core/cuda/cuda_rtc.cc | 10 +++-- tc/core/cuda/cuda_rtc.h | 6 +-- tc/core/cuda/cuda_tc_executor.cc | 15 +++++-- tc/core/cuda/cuda_tc_executor.h | 1 + tc/core/polyhedral/cuda/mapped_scop.cc | 55 +++++++++++++++++++++++--- tc/core/polyhedral/cuda/mapped_scop.h | 17 ++++++-- test/test_cuda_mapper.cc | 2 + 8 files changed, 86 insertions(+), 21 deletions(-) diff --git a/tc/core/cuda/cuda_backend.h b/tc/core/cuda/cuda_backend.h index f2cfb239a..3faf2dcc9 100644 --- a/tc/core/cuda/cuda_backend.h +++ b/tc/core/cuda/cuda_backend.h @@ -37,6 +37,7 @@ struct CudaCompilationResult { std::vector parameters; Grid grid; Block block; + bool useGridSync; }; /** diff --git a/tc/core/cuda/cuda_rtc.cc b/tc/core/cuda/cuda_rtc.cc index 0b959a4a2..31407319f 100644 --- a/tc/core/cuda/cuda_rtc.cc +++ b/tc/core/cuda/cuda_rtc.cc @@ -51,7 +51,8 @@ void CudaRTCFunction::clear() { std::unique_ptr CudaRTCFunction::Compile( const std::string& name, - const std::string& source) { + const std::string& source, + bool useGridSync) { std::unique_ptr res(new CudaRTCFunction()); res->specializedName = name; res->cleared_ = false; @@ -89,7 +90,7 @@ std::unique_ptr CudaRTCFunction::Compile( "-DNVRTC_CUB=1", cudaHome.c_str(), cubHome.c_str()}; - if (FLAGS_grid_sync) { + if (useGridSync) { nvrtcts.push_back("--relocatable-device-code=true"); } if (FLAGS_debug_cuda) { @@ -136,6 +137,7 @@ std::ostream& operator<<(std::ostream& os, const std::array& a) { Duration CudaRTCFunction::Launch( const std::array& grid, const std::array& block, + bool useGridSync, unsigned int shared_mem, cudaStream_t stream, std::vector params, @@ -147,7 +149,7 @@ Duration CudaRTCFunction::Launch( if (perGpuModule_.count(dev) == 0) { CUmodule module; CUfunction function; - if (FLAGS_grid_sync) { + if (useGridSync) { CUlinkState linkState; TC_CUDA_DRIVERAPI_ENFORCE(cuLinkCreate(0, 0, 0, &linkState)); TC_CUDA_DRIVERAPI_ENFORCE(cuLinkAddFile( @@ -198,7 +200,7 @@ Duration CudaRTCFunction::Launch( unsigned int by = block[1]; unsigned int bz = block[2]; auto launch = [&]() { - if (FLAGS_grid_sync) { + if (useGridSync) { TC_CUDA_DRIVERAPI_ENFORCE(cuLaunchCooperativeKernel( perGpuKernel_.at(dev), gx, diff --git a/tc/core/cuda/cuda_rtc.h b/tc/core/cuda/cuda_rtc.h index 4aa5e831f..25f46ee0d 100644 --- a/tc/core/cuda/cuda_rtc.h +++ b/tc/core/cuda/cuda_rtc.h @@ -41,14 +41,14 @@ class CudaRTCFunction { public: ~CudaRTCFunction(); - static std::unique_ptr Compile( - const std::string& name, - const std::string& source); + static std::unique_ptr + Compile(const std::string& name, const std::string& source, bool useGridSync); // if profile is set it returns the kernel runtime Duration Launch( const std::array& grid, const std::array& block, + bool useGridSync, unsigned int shared_mem, cudaStream_t stream, // by copy because we take an address to element when calling the kernel diff --git a/tc/core/cuda/cuda_tc_executor.cc b/tc/core/cuda/cuda_tc_executor.cc index cf27d5f29..e35d1b867 100644 --- a/tc/core/cuda/cuda_tc_executor.cc +++ b/tc/core/cuda/cuda_tc_executor.cc @@ -57,10 +57,13 @@ CudaTcExecutor::CudaTcExecutor( // force unloading in case we JIT with the same name/input/outputs with // different options. this->clearRuntimeCompiledFunction(); - rtcFun_ = CudaRTCFunction::Compile( - compilationResult.specializedName, compilationResult.source); grid_ = compilationResult.grid; block_ = compilationResult.block; + useGridSync_ = compilationResult.useGridSync; + rtcFun_ = CudaRTCFunction::Compile( + compilationResult.specializedName, + compilationResult.source, + useGridSync_); auto t1 = std::chrono::high_resolution_clock::now(); LOG_IF(INFO, FLAGS_debug_tc_mapper) << "[COMPILE] Compiling with host JIT compiler took: " @@ -100,12 +103,14 @@ CudaCompilationResult CudaBackend::compileWithTcMapper( std::string source; Grid grid; Block block; - std::tie(source, grid, block) = mappedScop->codegen(specializedName); + bool useGridSync; + std::tie(source, grid, block, useGridSync) = + mappedScop->codegen(specializedName); LOG_IF(INFO, FLAGS_dump_cuda) << "generatedCuda: " << source << "\n" << "grid: " << grid << " block: " << block; return CudaCompilationResult{ - source, specializedName, parameters, grid, block}; + source, specializedName, parameters, grid, block, useGridSync}; } void CudaTcExecutor::uncheckedRun( @@ -118,6 +123,7 @@ void CudaTcExecutor::uncheckedRun( rtcFun_->Launch( grid_.view.extractDefaultedArray(), block_.view.extractDefaultedArray(), + useGridSync_, 0, info.stream, parameters_, @@ -136,6 +142,7 @@ ProfilingInfo CudaTcExecutor::profileUnchecked( Duration kernelRuntime(rtcFun_->Launch( grid_.view.extractDefaultedArray(), block_.view.extractDefaultedArray(), + useGridSync_, 0, stream, parameters_, diff --git a/tc/core/cuda/cuda_tc_executor.h b/tc/core/cuda/cuda_tc_executor.h index a497c940d..bb0b7a911 100644 --- a/tc/core/cuda/cuda_tc_executor.h +++ b/tc/core/cuda/cuda_tc_executor.h @@ -63,5 +63,6 @@ class CudaTcExecutor : public TcExecutor { // GPU-specific results of compilation Grid grid_; Block block_; + bool useGridSync_; }; } // namespace tc diff --git a/tc/core/polyhedral/cuda/mapped_scop.cc b/tc/core/polyhedral/cuda/mapped_scop.cc index 5ffc59d9f..521134309 100644 --- a/tc/core/polyhedral/cuda/mapped_scop.cc +++ b/tc/core/polyhedral/cuda/mapped_scop.cc @@ -882,6 +882,7 @@ std::unique_ptr makeSpecializedMappedScop( std::move(scop), grid, block, + mappedScop.useGridSync, mappedScop.unroll, mappedScop.useReadOnlyCache); res->insertMappingContext(); @@ -898,7 +899,7 @@ std::unique_ptr makeSpecializedMappedScop( // Before generating code, make a copy of the scop and insert // the context of the original scop as top-level // context node in schedule tree. -std::tuple MappedScop::codegen( +std::tuple MappedScop::codegen( const std::string& specializedName) const { validate(schedule()); @@ -922,7 +923,8 @@ std::tuple MappedScop::codegen( return std::make_tuple( code.str(), mappedScopForCodegen->numBlocks, - mappedScopForCodegen->numThreads); + mappedScopForCodegen->numThreads, + mappedScopForCodegen->useGridSync); } // Split out a single reduction tile (in the directions other than @@ -1001,11 +1003,53 @@ std::unique_ptr MappedScop::makeWithOuterBlockInnerThreadStrategy( const CudaMappingOptions& cudaOptions) { using namespace polyhedral::detail; + // Query the relevant cuda device information + auto nbBlocks = 1; + for (size_t i = 0; i < cudaOptions.grid.size(); i++) { + nbBlocks *= cudaOptions.grid[i]; + } + auto nbThreads = 1; + for (size_t i = 0; i < cudaOptions.block.size(); i++) { + nbThreads *= cudaOptions.block[i]; + } + size_t sharedMemorySize = cudaOptions.proto().has_max_shared_memory() + ? cudaOptions.proto().max_shared_memory() + : querySharedMemorySize(); + size_t maxBlocksPerSM = cudaOptions.proto().has_max_blocks_per_sm() + ? cudaOptions.proto().max_blocks_per_sm() + : queryBlocksPerSM(); + size_t maxThreadsPerSM = cudaOptions.proto().has_max_threads_per_sm() + ? cudaOptions.proto().max_threads_per_sm() + : queryThreadsPerSM(); + size_t sharedMemorySizePerSM = + cudaOptions.proto().has_max_shared_memory_per_sm() + ? cudaOptions.proto().max_shared_memory_per_sm() + : querySharedMemorySizePerSM(); + size_t nbOfSM = cudaOptions.proto().has_nb_of_sm() + ? cudaOptions.proto().nb_of_sm() + : queryNbOfSM(); + auto blocksPerSM = nbOfSM == 0 ? 0 : ((nbBlocks + nbOfSM - 1) / nbOfSM); + auto threadsPerSM = nbThreads * blocksPerSM; + + bool useGridSync = FLAGS_grid_sync; + if (useGridSync) { + useGridSync &= nbOfSM * maxBlocksPerSM >= nbBlocks; + useGridSync &= maxThreadsPerSM > threadsPerSM; + } + + if (useGridSync) { + LOG(WARNING) << "Use grid sync" << std::endl; + } + if (FLAGS_grid_sync && !useGridSync) { + LOG(WARNING) << "Can't use grid sync" << std::endl; + } + const auto& generic = cudaOptions.generic; auto mappedScop = std::unique_ptr(new MappedScop( std::move(scopUPtr), ::tc::Grid(cudaOptions.grid), ::tc::Block(cudaOptions.block), + useGridSync, generic.proto.unroll(), cudaOptions.proto().use_readonly_cache())); auto& scop = mappedScop->scop_; @@ -1022,8 +1066,10 @@ std::unique_ptr MappedScop::makeWithOuterBlockInnerThreadStrategy( TC_CHECK_LT(0u, generic.tiling.size()) << "Must pass tile vector with >= 1 tile sizes"; std::vector tiledBands; - if (FLAGS_grid_sync) { + if (useGridSync) { tiledBands = scop->tileOuterCoincidentBands(generic.tiling); + sharedMemorySize = + std::min(sharedMemorySize, sharedMemorySizePerSM / blocksPerSM); } else { tiledBands = {scop->tileOuterBand(generic.tiling)}; } @@ -1087,9 +1133,6 @@ std::unique_ptr MappedScop::makeWithOuterBlockInnerThreadStrategy( // This may split the outer band, so find the new outer band after // promotion. if (cudaOptions.proto().use_shared_memory()) { - size_t sharedMemorySize = cudaOptions.proto().has_max_shared_memory() - ? cudaOptions.proto().max_shared_memory() - : querySharedMemorySize(); // If reductions found, their synchronization requires an opaque cache in // shared memory. Subtract 4k from available shared memory for each // reduction found, this is hack based on each thread of max 1024 in the diff --git a/tc/core/polyhedral/cuda/mapped_scop.h b/tc/core/polyhedral/cuda/mapped_scop.h index 5051a629c..c8e845173 100644 --- a/tc/core/polyhedral/cuda/mapped_scop.h +++ b/tc/core/polyhedral/cuda/mapped_scop.h @@ -61,11 +61,13 @@ class MappedScop { std::unique_ptr&& scop, ::tc::Grid grid, ::tc::Block block, + bool useGridSync_, uint64_t unroll_, bool useReadOnlyCache_) : scop_(std::move(scop)), numBlocks(grid), numThreads(block), + useGridSync(useGridSync_), unroll(unroll_), useReadOnlyCache(useReadOnlyCache_) {} @@ -73,7 +75,12 @@ class MappedScop { static inline std::unique_ptr makeOneBlockOneThread( std::unique_ptr&& scop) { auto mscop = std::unique_ptr(new MappedScop( - std::move(scop), ::tc::Grid{1, 1, 1}, ::tc::Block{1, 1, 1}, 1, false)); + std::move(scop), + ::tc::Grid{1, 1, 1}, + ::tc::Block{1, 1, 1}, + false, + 1, + false)); auto band = mscop->scop_->obtainOuterBand(); mscop->mapBlocksForward(band, 0); mscop->mapThreadsBackward(band); @@ -86,10 +93,11 @@ class MappedScop { std::unique_ptr&& scop, ::tc::Grid grid, ::tc::Block block, + bool useGridSync, uint64_t unroll, bool useReadOnlyCache) { - return std::unique_ptr( - new MappedScop(std::move(scop), grid, block, unroll, useReadOnlyCache)); + return std::unique_ptr(new MappedScop( + std::move(scop), grid, block, useGridSync, unroll, useReadOnlyCache)); } // Apply the hand-written OuterBlockInnerThread mapping strategy. @@ -121,7 +129,7 @@ class MappedScop { // Generate CUDA code at the current state of transformation provided a // name for the generated function. - std::tuple codegen( + std::tuple codegen( const std::string& specializedName) const; // Accessors.. @@ -237,6 +245,7 @@ class MappedScop { public: const ::tc::Grid numBlocks; const ::tc::Block numThreads; + bool useGridSync; const uint64_t unroll; const bool useReadOnlyCache; diff --git a/test/test_cuda_mapper.cc b/test/test_cuda_mapper.cc index 8b95b58eb..00cb28246 100644 --- a/test/test_cuda_mapper.cc +++ b/test/test_cuda_mapper.cc @@ -90,6 +90,7 @@ struct PolyhedralMapperTest : public ::testing::Test { std::move(scop), Grid{1}, Block{blockSizes[0], blockSizes[1]}, + false, 0, false); auto band = mscop->mapBlocksForward(root->child({0}), 1); @@ -113,6 +114,7 @@ struct PolyhedralMapperTest : public ::testing::Test { std::move(scop), Grid{gridSizes[0], gridSizes[1]}, Block{blockSizes[0], blockSizes[1]}, + false, 0, false); From 50530173d1ab0d0277d3a01165de95b3922ad68d Mon Sep 17 00:00:00 2001 From: math-fehr Date: Mon, 23 Apr 2018 14:02:07 +0200 Subject: [PATCH 06/16] Added test to check that grid synchronization is inserted. --- test/test_cuda_mapper.cc | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/test/test_cuda_mapper.cc b/test/test_cuda_mapper.cc index 00cb28246..a632d8f9b 100644 --- a/test/test_cuda_mapper.cc +++ b/test/test_cuda_mapper.cc @@ -668,15 +668,9 @@ def fun() -> (O) { /* * Check that a schedule tree without a single outer band gets mapped - * properly, i.e., that a band is inserted above the branching. - * Use the minimal fusion strategy to ensure the scheduler produces - * an outer sequence. - * Check that no synchronizations are inserted, since there is no - * dependences between threads. - * - * The test is disabled since it is now possible to map the operation to blocks. + * on blocks, and has a grid synchronization. */ -TEST_F(PolyhedralMapperTest, DISABLED_Copy2) { +TEST_F(PolyhedralMapperTest, Copy2) { auto tc = R"TC( def fun(float(N) I) -> (O1, O2) { O1(n) = I(n) @@ -686,14 +680,9 @@ def fun(float(N) I) -> (O1, O2) { auto mappingOptions = DefaultOptions(); mappingOptions.scheduleFusionStrategy(FusionStrategy::Min); auto code = codegenMapped(tc, mappingOptions); - auto loop = "for (int c0 = t0; c0 < N; c0 += 32)"; - auto blockSync = "__syncthreads();"; - auto pos1 = code.find(loop); - auto pos2 = code.find(loop, pos1 + 1); - auto pos3 = code.find(blockSync); - EXPECT_TRUE(pos1 != std::string::npos); - EXPECT_TRUE(pos2 != std::string::npos); - EXPECT_TRUE(pos3 == std::string::npos); + auto gridSync = "__syncgrid()"; + auto pos = code.find(gridSync); + EXPECT_TRUE(pos != std::string::npos); } /* From 48eeb661a5a5dadab2f46d5ebe54757e948c6d1d Mon Sep 17 00:00:00 2001 From: math-fehr Date: Thu, 26 Apr 2018 16:27:02 +0200 Subject: [PATCH 07/16] Added autotuner tests that behave differently with grid sync. --- test/cuda/test_autotuner.cc | 56 +++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/test/cuda/test_autotuner.cc b/test/cuda/test_autotuner.cc index 71e36a30a..02057400e 100644 --- a/test/cuda/test_autotuner.cc +++ b/test/cuda/test_autotuner.cc @@ -178,6 +178,62 @@ def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) { benchmarkKernelOptions(TC, name, inputs, bestOptions); } +TEST_F(ATenCompilationUnitTest, BatchNorm) { + static constexpr uint32_t N = 32, C = 4, H = 56, W = 56; + at::Tensor I = at::CUDA(at::kFloat).rand({N, C, H, W}); + at::Tensor rMeanIn = at::CUDA(at::kFloat).rand({C}); + at::Tensor rVarIn = at::CUDA(at::kFloat).rand({C}); + std::vector inputs = {I, rMeanIn, rVarIn}; + + // This may not be correct + static constexpr auto TC = R"TC( + def batchnorm(float(N,C,H,W) I, float(C) rMeanIn, float(C) rVarIn) -> (O, rMeanOut, rVarOut, mean, centered, variance, expectedVariance, normalizedOut) { + mean(c) +=! I(nn, c, hh, ww) + mean(c) = mean(c) / (N * H * W) + rMeanOut(c) = (1 - 0.2) * rMeanIn(c) + 0.2 * mean(c) + centered(n, c, h, w) = I(n, c, h, w) - rMeanOut(c) + variance(n, c, h, w) = centered(n, c, h, w) * centered(n, c, h, w) + expectedVariance(c) +=! (variance(n, c, h, w) + 0.001) / (N * H * W) + rVarOut(c) = rsqrt((1 - 0.2) * rVarIn(c) + 0.2 * expectedVariance(c)) + O(n, c, h, w) = centered(n, c, h, w) / rVarOut(c) + normalizedOut(n, c, h, w) = O(n, c, h, w) + } +)TC"; + + auto name = "batchnorm"; + auto options = tc::CudaMappingOptions::makeNaiveCudaMappingOptions(); + + std::string cacheFilename = ""; + auto bestOptions = + autotune(cacheFilename, TC, name, inputs, options, {options}); +} + +TEST_F(ATenCompilationUnitTest, GroupConvolution) { + uint32_t N = 32, G = 32, C = 16, F = 16; + uint32_t W = 14, H = 14, KW = 3, KH = 3; + at::Tensor I = at::CUDA(at::kFloat).rand({N, G, C, H, W}); + at::Tensor W1 = at::CUDA(at::kFloat).rand({G, F, C, KH, KW}); + at::Tensor B = at::CUDA(at::kFloat).rand({G, F}); + std::vector inputs = {I, W1, B}; + + static constexpr auto TC = R"( +def group_convolution(float(N,G,C,H,W) I, float(G,F,C,KH,KW) W1, float(G,F) B) +-> (O) +{ + O(n, g, f, h, w) +=! + I(n, g, r_c, h + r_kh, w + r_kw) * W1(g, f, r_c, r_kh, r_kw) + O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f) +} +)"; + + auto name = "group_convolution"; + auto options = tc::CudaMappingOptions::makeNaiveCudaMappingOptions(); + + std::string cacheFilename = ""; + auto bestOptions = + autotune(cacheFilename, TC, name, inputs, options, {options}); +} + int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); ::gflags::ParseCommandLineFlags(&argc, &argv, true); From 2eda8eee63c110dea9e8704babd7f8f012e604c2 Mon Sep 17 00:00:00 2001 From: math-fehr Date: Fri, 11 May 2018 18:06:06 +0200 Subject: [PATCH 08/16] Refactored memory promotion function to be called only once. --- tc/core/polyhedral/cuda/mapped_scop.cc | 42 ++++++++++--------- .../cuda/memory_promotion_heuristic.cc | 26 ++++++++---- .../cuda/memory_promotion_heuristic.h | 4 +- tc/core/polyhedral/scop.cc | 3 ++ test/test_cuda_mapper_memory_promotion.cc | 6 ++- 5 files changed, 50 insertions(+), 31 deletions(-) diff --git a/tc/core/polyhedral/cuda/mapped_scop.cc b/tc/core/polyhedral/cuda/mapped_scop.cc index 521134309..6a242302b 100644 --- a/tc/core/polyhedral/cuda/mapped_scop.cc +++ b/tc/core/polyhedral/cuda/mapped_scop.cc @@ -1155,30 +1155,34 @@ std::unique_ptr MappedScop::makeWithOuterBlockInnerThreadStrategy( << "requested to unroll copies to shared memory without providing the unroll size"; bool unroll = cudaOptions.proto().unroll_copy_shared() && generic.proto.has_unroll(); - for (auto outerBand : tiledBands) { - auto band = outerBand->elemAs(); - LOG_IF(WARNING, FLAGS_debug_tc_mapper && band->nMember() == 0) + std::vector depths; + std::vector bandsWithPromotion; + for (auto band : tiledBands) { + auto bandElem = band->elemAs(); + LOG_IF(WARNING, FLAGS_debug_tc_mapper && bandElem->nMember() == 0) << "Aborting memory promotion for one band because it has 0 members (NYI)"; - if (band->nMember() == 0) { + if (bandElem->nMember() == 0) { continue; } - - sharedMemorySize = promoteGreedilyAtDepth( - *mappedScop, - outerBand, - std::min( - band->nOuterCoincident(), mappedScop->numBlocks.view.size()), - sharedMemorySize, - cudaOptions.proto().unroll_copy_shared() && - generic.proto.has_unroll()); - - /*auto bands = ScheduleTree::collectDFSPreorder( - scop->scheduleRoot(), ScheduleTreeType::Band); - if (bands.size() == 0) { // Sanity check. - throw NoBandsException("no bands after promotion"); - }*/ + bandsWithPromotion.push_back(band); + depths.push_back(std::min( + bandElem->nOuterCoincident(), mappedScop->numBlocks.view.size())); } + + sharedMemorySize = promoteGreedilyAtDepth( + *mappedScop, + bandsWithPromotion, + depths, + sharedMemorySize, + cudaOptions.proto().unroll_copy_shared() && + generic.proto.has_unroll()); + + /*auto bands = ScheduleTree::collectDFSPreorder( + scop->scheduleRoot(), ScheduleTreeType::Band); + if (bands.size() == 0) { // Sanity check. + throw NoBandsException("no bands after promotion"); + }*/ } } diff --git a/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc b/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc index d3b76e77b..d153ef08e 100644 --- a/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc +++ b/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc @@ -490,13 +490,15 @@ std::vector bandsSplitAfterDepth( size_t promoteToSharedGreedy( Scop& scop, const Block& block, - detail::ScheduleTree* tree, - size_t depth, + std::vector trees, + std::vector depths, size_t maxMemory) { using namespace tc::polyhedral::detail; - if (depth == 0) { - throw promotion::PromotionNYI("promotion before any band"); + for (auto depth : depths) { + if (depth == 0) { + throw promotion::PromotionNYI("promotion before any band"); + } } auto root = scop.scheduleRoot(); @@ -504,8 +506,12 @@ size_t promoteToSharedGreedy( // 1. Collect all bands with a member located at the given depth in the // overall schedule. Make sure this is the last member of the band by // splitting off the subsequent members into a different band. - auto bands = bandsContainingScheduleDepth(tree, depth); - bands = bandsSplitAfterDepth(bands, root, depth); + std::vector bands; + for (size_t i = 0; i < trees.size(); ++i) { + auto treeBands = bandsContainingScheduleDepth(trees[i], depths[i]); + treeBands = bandsSplitAfterDepth(treeBands, root, depths[i]); + bands.insert(bands.end(), treeBands.begin(), treeBands.end()); + } // 2. Compute full schedule without mapping filters. The filters would make // it impossible to test for coalescing by incrementing a member of a band as @@ -599,6 +605,8 @@ size_t promoteToSharedGreedy( remainingMemory -= memoryRequirement; } } + LOG(WARNING) << *root; + LOG(WARNING) << *bandNode; scop.insertSyncsAroundCopies(bandNode); } return remainingMemory; @@ -607,13 +615,13 @@ size_t promoteToSharedGreedy( size_t promoteGreedilyAtDepth( MappedScop& mscop, - detail::ScheduleTree* st, - size_t depth, + std::vector trees, + std::vector depths, size_t sharedMemorySize, bool unrollCopies) { // 1. Promote using heuristic. sharedMemorySize = promoteToSharedGreedy( - mscop.scop(), mscop.numThreads, st, depth, sharedMemorySize); + mscop.scop(), mscop.numThreads, trees, depths, sharedMemorySize); // 2. Map copies to shared, state by copy mapCopiesToThreads(mscop, unrollCopies); diff --git a/tc/core/polyhedral/cuda/memory_promotion_heuristic.h b/tc/core/polyhedral/cuda/memory_promotion_heuristic.h index b4de0df19..4fb08fd89 100644 --- a/tc/core/polyhedral/cuda/memory_promotion_heuristic.h +++ b/tc/core/polyhedral/cuda/memory_promotion_heuristic.h @@ -42,8 +42,8 @@ class ScheduleTree; // Return the remaining memory. size_t promoteGreedilyAtDepth( MappedScop& scop, - detail::ScheduleTree* st, - std::size_t depth, + std::vector trees, + std::vector depths, std::size_t sharedMemorySize, bool unrollCopies); diff --git a/tc/core/polyhedral/scop.cc b/tc/core/polyhedral/scop.cc index 8d22101b3..cd9e3816a 100644 --- a/tc/core/polyhedral/scop.cc +++ b/tc/core/polyhedral/scop.cc @@ -208,6 +208,9 @@ void Scop::promoteGroup( void Scop::insertSyncsAroundCopies(ScheduleTree* tree) { // Return immediately if nothing was inserted + if (tree->numChildren() == 0) { + return; + } auto extensionNode = tree->child({0})->elemAs(); if (!extensionNode) { diff --git a/test/test_cuda_mapper_memory_promotion.cc b/test/test_cuda_mapper_memory_promotion.cc index d215d2501..1691d96ab 100644 --- a/test/test_cuda_mapper_memory_promotion.cc +++ b/test/test_cuda_mapper_memory_promotion.cc @@ -395,7 +395,11 @@ def fun(float(N, M) A) -> (B, C) { auto mscop = prepareScop( tc, {{"N", problemSize1}, {"M", problemSize2}}, {tileSize1, tileSize2}); promoteGreedilyAtDepth( - *mscop, mscop->scop().scheduleRoot(), depth, maxSharedMemory, false); + *mscop, + {mscop->scop().scheduleRoot()}, + {depth}, + maxSharedMemory, + false); return mscop; } }; From 06b672735cbc942c6a7a8c7bff5f9895138d99cb Mon Sep 17 00:00:00 2001 From: math-fehr Date: Mon, 7 May 2018 14:57:48 +0200 Subject: [PATCH 09/16] Modified batchnorm in autotuner. --- test/cuda/test_autotuner.cc | 47 ++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/test/cuda/test_autotuner.cc b/test/cuda/test_autotuner.cc index 02057400e..3ad750b41 100644 --- a/test/cuda/test_autotuner.cc +++ b/test/cuda/test_autotuner.cc @@ -179,33 +179,42 @@ def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) { } TEST_F(ATenCompilationUnitTest, BatchNorm) { - static constexpr uint32_t N = 32, C = 4, H = 56, W = 56; + static constexpr uint32_t N = 32, C = 4, H = 256, W = 256; at::Tensor I = at::CUDA(at::kFloat).rand({N, C, H, W}); at::Tensor rMeanIn = at::CUDA(at::kFloat).rand({C}); at::Tensor rVarIn = at::CUDA(at::kFloat).rand({C}); - std::vector inputs = {I, rMeanIn, rVarIn}; + at::Tensor eps = at::CUDA(at::kFloat).rand({1}); + eps[0] = 1.0f; + at::Tensor momentum = at::CUDA(at::kFloat).rand({1}); + momentum[0] = 1.0; + + std::vector inputs = {momentum, eps, I, rMeanIn, rVarIn}; - // This may not be correct static constexpr auto TC = R"TC( - def batchnorm(float(N,C,H,W) I, float(C) rMeanIn, float(C) rVarIn) -> (O, rMeanOut, rVarOut, mean, centered, variance, expectedVariance, normalizedOut) { - mean(c) +=! I(nn, c, hh, ww) - mean(c) = mean(c) / (N * H * W) - rMeanOut(c) = (1 - 0.2) * rMeanIn(c) + 0.2 * mean(c) - centered(n, c, h, w) = I(n, c, h, w) - rMeanOut(c) - variance(n, c, h, w) = centered(n, c, h, w) * centered(n, c, h, w) - expectedVariance(c) +=! (variance(n, c, h, w) + 0.001) / (N * H * W) - rVarOut(c) = rsqrt((1 - 0.2) * rVarIn(c) + 0.2 * expectedVariance(c)) - O(n, c, h, w) = centered(n, c, h, w) / rVarOut(c) - normalizedOut(n, c, h, w) = O(n, c, h, w) - } -)TC"; +def batchnorm( + float(1) momentum, float(1) eps, + float(N,C,H,W) I, float(C) rMeanIn, float(C) rVarIn) +-> (normalizedOut, rMeanOut, rVarOut, mean, centered, expectedVariance) +{ + mean(c) +=! I(r_n, c, r_h, r_w) + mean(c) = mean(c) / (N * H * W) + rMeanOut(c) = (1 - momentum(0)) * rMeanIn(c) + momentum(0) * mean(c) + centered(n, c, h, w) = I( n, c, h, w) - rMeanOut(c) + expectedVariance(c) +=! centered( n, c, h, w) * centered(n, c, h, w) + expectedVariance(c) = expectedVariance(c) / (N * H * W) + eps(0) + rVarOut(c) = rsqrt( + (1 - momentum(0)) * rVarIn(c) + + momentum(0) * expectedVariance(c)) + normalizedOut(n, c, h, w) = centered(n, c, h, w) * rVarOut(c) + })TC"; + auto name = "batchnorm"; - auto options = tc::CudaMappingOptions::makeNaiveCudaMappingOptions(); + auto options = tc::CudaMappingOptions::makeNaiveMappingOptions(); std::string cacheFilename = ""; auto bestOptions = - autotune(cacheFilename, TC, name, inputs, options, {options}); + autotune(TC, name, inputs, cacheFilename, options); } TEST_F(ATenCompilationUnitTest, GroupConvolution) { @@ -227,11 +236,11 @@ def group_convolution(float(N,G,C,H,W) I, float(G,F,C,KH,KW) W1, float(G,F) B) )"; auto name = "group_convolution"; - auto options = tc::CudaMappingOptions::makeNaiveCudaMappingOptions(); + auto options = tc::CudaMappingOptions::makeNaiveMappingOptions(); std::string cacheFilename = ""; auto bestOptions = - autotune(cacheFilename, TC, name, inputs, options, {options}); + autotune(TC, name, inputs, cacheFilename, options); } int main(int argc, char** argv) { From 27fc85cdb4bee9b9fb0c91b8cd833192e84724bb Mon Sep 17 00:00:00 2001 From: math-fehr Date: Mon, 7 May 2018 14:58:47 +0200 Subject: [PATCH 10/16] Added batchnorm tests in test_caffe2 --- test/caffe2/cuda/test_caffe2.cc | 136 ++++++++++++++++++++++++++++++++ test/cuda/test_autotuner.cc | 7 +- 2 files changed, 138 insertions(+), 5 deletions(-) diff --git a/test/caffe2/cuda/test_caffe2.cc b/test/caffe2/cuda/test_caffe2.cc index ec72b5db6..d33f1c332 100644 --- a/test/caffe2/cuda/test_caffe2.cc +++ b/test/caffe2/cuda/test_caffe2.cc @@ -499,6 +499,142 @@ def fun(float(B, N, M) X, float(B, M, K) Y) -> (Z) CheckEqual(w_ref, w_test, "Z", 1e-6); } +TEST_F(Caffe2Test, TcBatchNormNoGrid) { + static constexpr uint32_t N = 32, C = 4, H = 256, W = 256; + auto init_ws = [&](Workspace& w) { + auto AddInput = AddDeterministicallyRandomInput; + AddInput(w, {N, C, H, W}, "I"); + AddInput(w, {C}, "rMeanIn"); + AddInput(w, {C}, "rVarIn"); + }; + + auto tc = R"TC( + def batchnormnogrid(float(N,C,H,W) I, float(C) rMeanIn, float(C) rVarIn) -> (O, rMeanOut, rVarOut, mean, centered, variance, expectedVariance, normalizedOut) { + mean(c) +=! I(nn, c, hh, ww) + mean(c) = mean(c) / (N * H * W) + rMeanOut(c) = (1 - 0.2) * rMeanIn(c) + 0.2 * mean(c) + centered(n, c, h, w) = I(n, c, h, w) - rMeanOut(c) + variance(n, c, h, w) = centered(n, c, h, w) * centered(n, c, h, w) + expectedVariance(c) +=! (variance(n, c, h, w) + 0.001) / (N * H * W) + rVarOut(c) = rsqrt((1 - 0.2) * rVarIn(c) + 0.2 * expectedVariance(c)) + O(n, c, h, w) = centered(n, c, h, w) / rVarOut(c) + normalizedOut(n, c, h, w) = O(n, c, h, w) + } +)TC"; + + Workspace w_test; + init_ws(w_test); + Argument tcArg = MakeArgument("tc_def", tc); + Argument tcNameArg = MakeArgument("tc_name", "batchnormnogrid"); + CudaMappingOptions options = + tc::makeBaseCliStrategy() + .outerScheduleFusionStrategy(tc::FusionStrategy::Max) + .outerScheduleAllowSkewing(false) + .outerSchedulePositiveOrthant(true) + .intraTileScheduleFusionStrategy(tc::FusionStrategy::Min) + .intraTileScheduleAllowSkewing(false) + .intraTileSchedulePositiveOrthant(true) + .tile(2, 1, 8, 1) + .unroll(1) + .tileImperfectlyNested(false) + .matchLibraryCalls(true) + .mapToThreads(256, 2) + .mapToBlocks(256) + .useSharedMemory(true) + .usePrivateMemory(false) + .unrollCopyShared(true); + Argument strategyArg = MakeArgument( + "mappingOptions", + tc::makeCliStrategy(options).toProtobufSerializedString()); + auto op_def = MakeOperatorDef( + "TcOp", + {"I", "rMeanIn", "rVarIn"}, + {"O", + "rMeanOut", + "rVarOut", + "mean", + "centered", + "variance", + "expectedVariance", + "normalizedOut"}, + {tcArg, tcNameArg, strategyArg}); + auto op = CreateOperator(op_def, &w_test); + ASSERT_TRUE(op.get()); + ASSERT_TRUE(op->Run()); + + TC_CUDA_RUNTIMEAPI_ENFORCE(cudaDeviceSynchronize()); + + // CheckEqual(w_ref, w_test, "Z", 1e-6); +} + +TEST_F(Caffe2Test, TcBatchNormGrid) { + static constexpr uint32_t N = 32, C = 4, H = 256, W = 256; + auto init_ws = [&](Workspace& w) { + auto AddInput = AddDeterministicallyRandomInput; + AddInput(w, {N, C, H, W}, "I"); + AddInput(w, {C}, "rMeanIn"); + AddInput(w, {C}, "rVarIn"); + }; + + auto tc = R"TC( + def batchnormgrid(float(N,C,H,W) I, float(C) rMeanIn, float(C) rVarIn) -> (O, rMeanOut, rVarOut, mean, centered, variance, expectedVariance, normalizedOut) { + mean(c) +=! I(nn, c, hh, ww) + mean(c) = mean(c) / (N * H * W) + rMeanOut(c) = (1 - 0.2) * rMeanIn(c) + 0.2 * mean(c) + centered(n, c, h, w) = I(n, c, h, w) - rMeanOut(c) + variance(n, c, h, w) = centered(n, c, h, w) * centered(n, c, h, w) + expectedVariance(c) +=! (variance(n, c, h, w) + 0.001) / (N * H * W) + rVarOut(c) = rsqrt((1 - 0.2) * rVarIn(c) + 0.2 * expectedVariance(c)) + O(n, c, h, w) = centered(n, c, h, w) * rVarOut(c) + normalizedOut(n, c, h, w) = O(n, c, h, w) + } +)TC"; + + Workspace w_test; + init_ws(w_test); + Argument tcArg = MakeArgument("tc_def", tc); + Argument tcNameArg = MakeArgument("tc_name", "batchnormgrid"); + CudaMappingOptions options = + tc::makeBaseCliStrategy() + .outerScheduleFusionStrategy(tc::FusionStrategy::Preserve3Coincident) + .outerScheduleAllowSkewing(false) + .outerSchedulePositiveOrthant(true) + .intraTileScheduleFusionStrategy(tc::FusionStrategy::Min) + .intraTileScheduleAllowSkewing(false) + .intraTileSchedulePositiveOrthant(true) + .tile(2, 64) + .unroll(2) + .tileImperfectlyNested(false) + .matchLibraryCalls(true) + .mapToThreads(32, 2) + .mapToBlocks(128) + .useSharedMemory(false) + .usePrivateMemory(true) + .unrollCopyShared(false); + Argument strategyArg = MakeArgument( + "mappingOptions", + tc::makeCliStrategy(options).toProtobufSerializedString()); + auto op_def = MakeOperatorDef( + "TcOp", + {"I", "rMeanIn", "rVarIn"}, + {"O", + "rMeanOut", + "rVarOut", + "mean", + "centered", + "variance", + "expectedVariance", + "normalizedOut"}, + {tcArg, tcNameArg, strategyArg}); + auto op = CreateOperator(op_def, &w_test); + ASSERT_TRUE(op.get()); + ASSERT_TRUE(op->Run()); + + TC_CUDA_RUNTIMEAPI_ENFORCE(cudaDeviceSynchronize()); + + // CheckEqual(w_ref, w_test, "Z", 1e-6); +} + // TODO: TEST_F(Caffe2Test, DISABLED_TcGather) { auto init_ws = [&](Workspace& w) { diff --git a/test/cuda/test_autotuner.cc b/test/cuda/test_autotuner.cc index 3ad750b41..e03f994a2 100644 --- a/test/cuda/test_autotuner.cc +++ b/test/cuda/test_autotuner.cc @@ -208,13 +208,11 @@ def batchnorm( normalizedOut(n, c, h, w) = centered(n, c, h, w) * rVarOut(c) })TC"; - auto name = "batchnorm"; auto options = tc::CudaMappingOptions::makeNaiveMappingOptions(); std::string cacheFilename = ""; - auto bestOptions = - autotune(TC, name, inputs, cacheFilename, options); + auto bestOptions = autotune(TC, name, inputs, cacheFilename, options); } TEST_F(ATenCompilationUnitTest, GroupConvolution) { @@ -239,8 +237,7 @@ def group_convolution(float(N,G,C,H,W) I, float(G,F,C,KH,KW) W1, float(G,F) B) auto options = tc::CudaMappingOptions::makeNaiveMappingOptions(); std::string cacheFilename = ""; - auto bestOptions = - autotune(TC, name, inputs, cacheFilename, options); + auto bestOptions = autotune(TC, name, inputs, cacheFilename, options); } int main(int argc, char** argv) { From 8084e97856170197aa70c29a23516eb84c4a5fa4 Mon Sep 17 00:00:00 2001 From: math-fehr Date: Fri, 11 May 2018 13:17:39 +0200 Subject: [PATCH 11/16] Changed promotion to apply heuristic on all bands at the same time. Instead of applying it on every branch one by one greedily. --- .../cuda/memory_promotion_heuristic.cc | 180 ++++++++++-------- 1 file changed, 103 insertions(+), 77 deletions(-) diff --git a/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc b/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc index d153ef08e..8f6f35e65 100644 --- a/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc +++ b/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc @@ -29,6 +29,7 @@ #include #include #include +#include namespace tc { namespace polyhedral { @@ -476,6 +477,13 @@ std::vector bandsSplitAfterDepth( return functional::Map(splitAtDepth, bands); } +struct IslIdSizeTHash { + size_t operator()(const std::pair& k) const { + return (isl::IslIdIslHash()(k.first) ^ (k.second << 1)); + } +}; +} // namespace + /* * For every place in the schedule tree where schedule depth (i.e., the number * of preceding band members) is "depth", promote tensor reference groups to @@ -524,89 +532,107 @@ size_t promoteToSharedGreedy( // group either features reuse or is accessed in a non-coalesced way, or // both. size_t remainingMemory = maxMemory; - for (auto bandNode : bands) { - auto activePoints = activeDomainPoints(root, bandNode); - auto partialSched = partialSchedule(root, bandNode); - - auto groupMap = TensorReferenceGroup::accessedWithin( - partialSched.intersect_domain(activePoints), scop.reads, scop.writes); - // Pure affine schedule without (mapping) filters. - auto partialSchedMupa = partialScheduleMupa(root, bandNode); - - // Prepare groups for sorting, to have specified order necessary for - // reproducibility and tests. - using TensorGroupList = std::pair; - std::vector groupLists( - std::make_move_iterator(groupMap.begin()), - std::make_move_iterator(groupMap.end())); - - // Computes the total number of references in all groups. - auto refsCount = [](const TensorGroupsInfo& info) { - size_t refs = 0; - for (auto const& group : info) { - refs += group->referenceIds().size(); - } - return refs; - }; + std::vector partialScheds; + // Pure affine schedule without (mapping) filters. + std::vector partialSchedsMupa; + std::unordered_map< + std::pair, + TensorGroupsInfo, + IslIdSizeTHash> + groupMap; + + for (size_t i = 0; i < bands.size(); ++i) { + auto activePoints = activeDomainPoints(root, bands[i]); + auto groupMapOfSubtree = TensorReferenceGroup::accessedWithin(partialSched.intersect_domain(activePoints), scop.reads, scop.writes); + partialScheds.push_back(partialSchedule(root, bands[i])); + partialSchedsMupa.push_back(partialScheduleMupa(root, bands[i])); + + for (auto& tensorGroup : groupMapOfSubtree) { + groupMap[std::make_pair(tensorGroup.first, i)] = + std::move(tensorGroup.second); + } + } - // Sort by the total number of references, then by name. Because names are - // guarenteed to be unique, the order is total. + // Prepare groups for sorting, to have specified order necessary for + // reproducibility and tests. + using TensorGroupList = + std::pair, TensorGroupsInfo>; + std::vector groupLists( + std::make_move_iterator(groupMap.begin()), + std::make_move_iterator(groupMap.end())); + + // Computes the total number of references in all groups. + auto refsCount = [](const TensorGroupsInfo& info) { + size_t refs = 0; + for (auto const& group : info) { + refs += group->referenceIds().size(); + } + return refs; + }; + + // Sort by the total number of references, then by name. Because names are + // guarenteed to be unique, the order is total. + std::sort( + groupLists.begin(), + groupLists.end(), + [refsCount](const TensorGroupList& l1, const TensorGroupList& l2) { + auto r1 = refsCount(l1.second); + auto r2 = refsCount(l2.second); + auto n1 = l1.first.first.get_name(); + auto n2 = l2.first.first.get_name(); + return std::tie(r1, n1, l1.first.second) < + std::tie(r2, n2, l2.first.second); + }); + + for (auto& tensorGroups : groupLists) { + auto tensorId = tensorGroups.first.first; + auto bandNodeId = tensorGroups.first.second; + // Sort the reference groups to prioritize groups with more references as + // they are more likely to benefit from promotion. std::sort( - groupLists.begin(), - groupLists.end(), - [refsCount](const TensorGroupList& l1, const TensorGroupList& l2) { - auto r1 = refsCount(l1.second); - auto r2 = refsCount(l2.second); - return r1 == r2 ? l1.first.get_name() < l2.first.get_name() : r1 < r2; + tensorGroups.second.begin(), + tensorGroups.second.end(), + [refsCount]( + const std::unique_ptr& group1, + const std::unique_ptr& group2) { + return group1->referenceIds().size() > group2->referenceIds().size(); }); - for (auto& tensorGroups : groupLists) { - auto tensorId = tensorGroups.first; - // Sort the reference groups to prioritize groups with more references as - // they are more likely to benefit from promotion. - std::sort( - tensorGroups.second.begin(), - tensorGroups.second.end(), - [refsCount]( - const std::unique_ptr& group1, - const std::unique_ptr& group2) { - return group1->referenceIds().size() > - group2->referenceIds().size(); - }); - - for (auto& group : tensorGroups.second) { - auto sizes = group->approximationSizes(); - if (sizes.size() == 0) { - throw promotion::PromotionLogicError("cannot promote a scalar"); - } - if (sizes.back() % 2 == 0) { - sizes.back() += 1; - } - auto nApproximationElements = std::accumulate( - sizes.begin(), sizes.end(), 1, std::multiplies()); - size_t memoryRequirement = - nApproximationElements * scop.findArgument(tensorId).type().bytes(); - if (memoryRequirement > remainingMemory) { - continue; - } - // Do not promote if the group features no reuse and is accessed in a - // coalesced way. - if (!hasReuseWithin(*group, partialSchedMupa) && - !promotionImprovesCoalescing(root, bandNode, *group, fullSched)) { - continue; - } - scop.promoteGroup( - Scop::PromotedDecl::Kind::SharedMem, - tensorId, - std::move(group), - bandNode, - partialSched, - true); - remainingMemory -= memoryRequirement; + for (auto& group : tensorGroups.second) { + auto sizes = group->approximationSizes(); + if (sizes.size() == 0) { + throw promotion::PromotionLogicError("cannot promote a scalar"); + } + if (sizes.back() % 2 == 0) { + sizes.back() += 1; + } + auto nApproximationElements = std::accumulate( + sizes.begin(), sizes.end(), 1, std::multiplies()); + size_t memoryRequirement = + nApproximationElements * scop.findArgument(tensorId).type().bytes(); + if (memoryRequirement > remainingMemory) { + continue; + } + // Do not promote if the group features no reuse and is accessed in a + // coalesced way. + if (!hasReuseWithin(*group, partialSchedsMupa[bandNodeId]) && + !promotionImprovesCoalescing( + root, bands[bandNodeId], *group, fullSched)) { + continue; } + + scop.promoteGroup( + Scop::PromotedDecl::Kind::SharedMem, + tensorId, + std::move(group), + bands[bandNodeId], + partialScheds[bandNodeId], + true); + remainingMemory -= memoryRequirement; } - LOG(WARNING) << *root; - LOG(WARNING) << *bandNode; + } + + for (auto bandNode : bands) { scop.insertSyncsAroundCopies(bandNode); } return remainingMemory; From 5bd3ef4cdecfd083076125f92f5c26d2ad69e79a Mon Sep 17 00:00:00 2001 From: math-fehr Date: Wed, 16 May 2018 18:12:11 +0200 Subject: [PATCH 12/16] Fixed bug when looking for outermost coincident bands. Bands should not be inserted between a sequence and a filter. Also, depth wasn't computed correctly. --- tc/core/polyhedral/cuda/mapped_scop.cc | 11 ++++++++--- tc/core/polyhedral/scop.cc | 3 ++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/tc/core/polyhedral/cuda/mapped_scop.cc b/tc/core/polyhedral/cuda/mapped_scop.cc index 6a242302b..0d1c834c7 100644 --- a/tc/core/polyhedral/cuda/mapped_scop.cc +++ b/tc/core/polyhedral/cuda/mapped_scop.cc @@ -1077,7 +1077,8 @@ std::unique_ptr MappedScop::makeWithOuterBlockInnerThreadStrategy( // 4. Optionally reschedule if point loops need a different strategy than // tile loops for (auto outerBand : tiledBands) { - if (generic.outerScheduleOptions != generic.intraTileScheduleOptions) { + if (generic.outerScheduleOptions != generic.intraTileScheduleOptions && + outerBand->numChildren() != 0) { scop->reschedule(outerBand->child({0}), generic.intraTileScheduleOptions); } } @@ -1166,8 +1167,12 @@ std::unique_ptr MappedScop::makeWithOuterBlockInnerThreadStrategy( continue; } bandsWithPromotion.push_back(band); - depths.push_back(std::min( - bandElem->nOuterCoincident(), mappedScop->numBlocks.view.size())); + auto depthBefore = band->scheduleDepth(scop->scheduleRoot()); + depths.push_back( + depthBefore + + std::min( + bandElem->nOuterCoincident(), + mappedScop->numBlocks.view.size())); } sharedMemorySize = promoteGreedilyAtDepth( diff --git a/tc/core/polyhedral/scop.cc b/tc/core/polyhedral/scop.cc index cd9e3816a..44fcc55de 100644 --- a/tc/core/polyhedral/scop.cc +++ b/tc/core/polyhedral/scop.cc @@ -501,7 +501,8 @@ std::vector obtainOuterCoincidentBands( auto domain = root->elemAs(); CHECK(domain); auto band = ScheduleTree::makeEmptyBand(root); - if (nChildren == 0 || root == tree) { + if (nChildren == 0 || tree->elemAs() || + tree->elemAs()) { return {setPermutable(insertNodeBelow(tree, std::move(band)))}; } else { return {setPermutable(insertNodeAbove(root, tree, std::move(band)))}; From 5e2e586a10be1f12a18b10e80761cc45e3e739de Mon Sep 17 00:00:00 2001 From: math-fehr Date: Fri, 25 May 2018 13:04:05 +0200 Subject: [PATCH 13/16] Added reduce_launch flag and set grid sync default to false. --- tc/autotuner/parameters.cc | 32 ++++++++++++++++++++++---------- tc/core/flags.cc | 3 ++- tc/core/flags.h | 1 + 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/tc/autotuner/parameters.cc b/tc/autotuner/parameters.cc index 1d421e495..4363c303a 100644 --- a/tc/autotuner/parameters.cc +++ b/tc/autotuner/parameters.cc @@ -347,30 +347,42 @@ TuningConfiguration::TuningConfiguration() useReadOnlyCache("use readonly cache (i.e. emit __ldg loads)"), matchLibraryCalls("match library calls") { addValidator([](const TuningConfiguration& conf) { - auto b0v = conf.blockParams.dims.at(0).value(); - auto b1v = conf.blockParams.dims.at(1).value(); - auto b2v = conf.blockParams.dims.at(2).value(); + auto b = conf.blockParams; + auto b0v = b.dims.at(0).value(); + auto b1v = b.dims.at(1).value(); + auto b2v = b.dims.at(2).value(); + auto g = conf.gridParams; + auto g0v = g.dims.at(0).value(); + auto g1v = g.dims.at(1).value(); + auto g2v = g.dims.at(2).value(); if (b0v <= 0 or b0v > 1024 or b1v <= 0 or b1v > 1024 or b2v <= 0 or b2v > 64) { return false; } - auto blockProduct = [&]() { - switch (conf.blockParams.numberDims.value()) { + auto computeProduct = [&](const CudaDimParameters& p) { + switch (p.numberDims.value()) { case 3: - return b0v * b1v * b2v; + return p.dims.at(0).value() * p.dims.at(1).value() * + p.dims.at(2).value(); case 2: - return b0v * b1v; + return p.dims.at(0).value() * p.dims.at(1).value(); case 1: - return b0v; + return p.dims.at(0).value(); default: TC_CHECK(false) << "Must have (1-3) block dims, got: " << conf.blockParams.numberDims.value(); } - return b0v; - }(); + return p.dims.at(0).value(); + }; + auto blockProduct = computeProduct(b); + auto gridProduct = computeProduct(g); if (blockProduct < 32 or blockProduct > 512) { return false; } + if (FLAGS_reduce_launch_size and + (gridProduct > 128 or blockProduct > 256)) { + return false; + } return true; }); } diff --git a/tc/core/flags.cc b/tc/core/flags.cc index d1f342426..b12c74ea8 100644 --- a/tc/core/flags.cc +++ b/tc/core/flags.cc @@ -37,7 +37,8 @@ DEFINE_bool( "Print debug spew for the tc_mapper like cuda code, mapping options etc"); DEFINE_bool(dump_cuda, false, "Print the generated source"); DEFINE_bool(dump_ptx, false, "Dump the generated PTX"); -DEFINE_bool(grid_sync, true, "Use the grid sync feature."); +DEFINE_bool(grid_sync, false, "Use the grid sync feature."); +DEFINE_bool(reduce_launch_size, false, "Reduce the launch size."); // CPU codegen options DEFINE_bool(llvm_dump_before_opt, false, "Print IR before optimization"); diff --git a/tc/core/flags.h b/tc/core/flags.h index 86abcc1ab..e49d14b3a 100644 --- a/tc/core/flags.h +++ b/tc/core/flags.h @@ -31,6 +31,7 @@ DECLARE_bool(debug_tuner); DECLARE_bool(dump_cuda); DECLARE_bool(dump_ptx); DECLARE_bool(grid_sync); +DECLARE_bool(reduce_launch_size); // llvm codegen DECLARE_bool(llvm_dump_before_opt); From 294ae70dd3f7cccac0a652f10ec372006184a20f Mon Sep 17 00:00:00 2001 From: math-fehr Date: Fri, 25 May 2018 17:34:27 +0200 Subject: [PATCH 14/16] Compile tc_cuda_version only when using CUDA, and fix bug --- tc/core/CMakeLists.txt | 8 +++++++- tc/version/CMakeLists.txt | 3 +-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tc/core/CMakeLists.txt b/tc/core/CMakeLists.txt index ff64afd31..f8b6089e7 100644 --- a/tc/core/CMakeLists.txt +++ b/tc/core/CMakeLists.txt @@ -45,12 +45,18 @@ target_link_libraries( -lLLVM tc_lang - tc_cuda_version + tc_version tc_proto ) + if (WITH_BINDINGS) add_dependencies(tc_core generate_isl_cpp_h) endif() + +if(WITH_CUDA) + target_link_libraries(tc_cuda_version) +endif() + install( TARGETS tc_core diff --git a/tc/version/CMakeLists.txt b/tc/version/CMakeLists.txt index 1ad1e61a3..31b400618 100644 --- a/tc/version/CMakeLists.txt +++ b/tc/version/CMakeLists.txt @@ -11,8 +11,7 @@ target_include_directories(tc_version INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) if (WITH_CUDA) find_library(CUDA_LIBDEVRT_PATH libcudadevrt.a HINTS ${CUDA_TOOLKIT_ROOT_DIR} - PATH_SUFFIXES lib lib64) - SET(CUDA_LIBDEVRT_PATH ${CUDA_LIBDEVRT_PATH}) + PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib targets/x86_64-linux/lib/stubs) configure_file(cuda_version.cc.in cuda_version.cc) add_library(tc_cuda_version STATIC ${CMAKE_CURRENT_BINARY_DIR}/cuda_version.cc) SET_TARGET_PROPERTIES(tc_cuda_version PROPERTIES LINKER_LANGUAGE CXX) From da18b08ae3f7e206c7fc3200ce89b6509e06c456 Mon Sep 17 00:00:00 2001 From: math-fehr Date: Tue, 19 Jun 2018 16:29:37 +0200 Subject: [PATCH 15/16] Fixed bug modifying the tile size for some outer coincident bands --- tc/core/polyhedral/scop.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tc/core/polyhedral/scop.cc b/tc/core/polyhedral/scop.cc index 44fcc55de..c7829fee2 100644 --- a/tc/core/polyhedral/scop.cc +++ b/tc/core/polyhedral/scop.cc @@ -561,9 +561,9 @@ std::vector Scop::tileOuterCoincidentBands( const TilingView& tileSizes) { using namespace tc::polyhedral::detail; auto bands = obtainOuterCoincidentBands(scheduleRoot(), scheduleRoot()); - std::vector sizes = tileSizes.extractVector(); std::vector tiledBands = {}; for (auto band : bands) { + std::vector sizes = tileSizes.extractVector(); auto bandNode = band->elemAs(); if (bandNode->nMember() < sizes.size()) { sizes.resize(bandNode->nMember()); From 563614b9abd203433e33441576d092cdca41facb Mon Sep 17 00:00:00 2001 From: math-fehr Date: Tue, 19 Jun 2018 16:31:34 +0200 Subject: [PATCH 16/16] Fixed bug with division by 0 --- tc/core/polyhedral/cuda/mapped_scop.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tc/core/polyhedral/cuda/mapped_scop.cc b/tc/core/polyhedral/cuda/mapped_scop.cc index 0d1c834c7..53102529b 100644 --- a/tc/core/polyhedral/cuda/mapped_scop.cc +++ b/tc/core/polyhedral/cuda/mapped_scop.cc @@ -1068,8 +1068,9 @@ std::unique_ptr MappedScop::makeWithOuterBlockInnerThreadStrategy( std::vector tiledBands; if (useGridSync) { tiledBands = scop->tileOuterCoincidentBands(generic.tiling); - sharedMemorySize = - std::min(sharedMemorySize, sharedMemorySizePerSM / blocksPerSM); + sharedMemorySize = std::min( + sharedMemorySize, + blocksPerSM == 0 ? 0 : sharedMemorySizePerSM / blocksPerSM); } else { tiledBands = {scop->tileOuterBand(generic.tiling)}; }