diff --git a/tc/autotuner/parameters.cc b/tc/autotuner/parameters.cc index 1d421e495..4363c303a 100644 --- a/tc/autotuner/parameters.cc +++ b/tc/autotuner/parameters.cc @@ -347,30 +347,42 @@ TuningConfiguration::TuningConfiguration() useReadOnlyCache("use readonly cache (i.e. emit __ldg loads)"), matchLibraryCalls("match library calls") { addValidator([](const TuningConfiguration& conf) { - auto b0v = conf.blockParams.dims.at(0).value(); - auto b1v = conf.blockParams.dims.at(1).value(); - auto b2v = conf.blockParams.dims.at(2).value(); + auto b = conf.blockParams; + auto b0v = b.dims.at(0).value(); + auto b1v = b.dims.at(1).value(); + auto b2v = b.dims.at(2).value(); + auto g = conf.gridParams; + auto g0v = g.dims.at(0).value(); + auto g1v = g.dims.at(1).value(); + auto g2v = g.dims.at(2).value(); if (b0v <= 0 or b0v > 1024 or b1v <= 0 or b1v > 1024 or b2v <= 0 or b2v > 64) { return false; } - auto blockProduct = [&]() { - switch (conf.blockParams.numberDims.value()) { + auto computeProduct = [&](const CudaDimParameters& p) { + switch (p.numberDims.value()) { case 3: - return b0v * b1v * b2v; + return p.dims.at(0).value() * p.dims.at(1).value() * + p.dims.at(2).value(); case 2: - return b0v * b1v; + return p.dims.at(0).value() * p.dims.at(1).value(); case 1: - return b0v; + return p.dims.at(0).value(); default: TC_CHECK(false) << "Must have (1-3) block dims, got: " << conf.blockParams.numberDims.value(); } - return b0v; - }(); + return p.dims.at(0).value(); + }; + auto blockProduct = computeProduct(b); + auto gridProduct = computeProduct(g); if (blockProduct < 32 or blockProduct > 512) { return false; } + if (FLAGS_reduce_launch_size and + (gridProduct > 128 or blockProduct > 256)) { + return false; + } return true; }); } diff --git a/tc/core/CMakeLists.txt b/tc/core/CMakeLists.txt index 4435f1533..f8b6089e7 100644 --- a/tc/core/CMakeLists.txt +++ b/tc/core/CMakeLists.txt @@ -48,9 +48,15 @@ target_link_libraries( tc_version tc_proto ) + if (WITH_BINDINGS) add_dependencies(tc_core generate_isl_cpp_h) endif() + +if(WITH_CUDA) + target_link_libraries(tc_cuda_version) +endif() + install( TARGETS tc_core @@ -176,6 +182,7 @@ if (WITH_CUDA) tc_lang tc_version + tc_cuda_version tc_proto tc_core ) diff --git a/tc/core/constants.h b/tc/core/constants.h index 8ec277583..92c779209 100644 --- a/tc/core/constants.h +++ b/tc/core/constants.h @@ -29,6 +29,7 @@ constexpr auto kReadIdName = "read"; constexpr auto kWriteIdName = "write"; constexpr auto kSyncIdPrefix = "_sync_"; constexpr auto kWarpSyncIdPrefix = "_warpSync_"; +constexpr auto kGridSyncIdPrefix = "_gridSync_"; } // namespace polyhedral } // namespace tc diff --git a/tc/core/cuda/cuda.cc b/tc/core/cuda/cuda.cc index 108e058cd..6fa492f74 100644 --- a/tc/core/cuda/cuda.cc +++ b/tc/core/cuda/cuda.cc @@ -30,7 +30,14 @@ DEFINE_bool(use_nvprof, false, "Start / stop nvprof"); namespace { -std::tuple, std::vector> init() { +std::tuple< + std::vector, + std::vector, + std::vector, + std::vector, + std::vector, + std::vector> +init() { int deviceCount = 0; auto err_id = cudaGetDeviceCount(&deviceCount); if (err_id == 35 or err_id == 30) { @@ -44,14 +51,36 @@ std::tuple, std::vector> init() { } std::vector gpuNames; std::vector sharedMemSizes; + std::vector sharedMemSizesPerSM; + std::vector blocksPerSM; + std::vector threadsPerSM; + std::vector nbOfSM; gpuNames.reserve(deviceCount); for (int i = 0; i < deviceCount; ++i) { cudaDeviceProp deviceProp; TC_CUDA_RUNTIMEAPI_ENFORCE(cudaGetDeviceProperties(&deviceProp, i)); 
gpuNames.emplace_back(deviceProp.name); sharedMemSizes.emplace_back(deviceProp.sharedMemPerBlock); + sharedMemSizesPerSM.emplace_back(deviceProp.sharedMemPerMultiprocessor); + + // There is currently no way to get the number of blocks per sm + // with the CUDA api. The only relevant solution is to compute it + // with the compute capability. + // the formula works if the number of blocks per sm is nondecreasing after + // the 6.0 compute capability. + auto major = deviceProp.major; + blocksPerSM.emplace_back(major < 3 ? 8 : (major < 4 ? 16 : 32)); + + threadsPerSM.emplace_back(deviceProp.maxThreadsPerMultiProcessor); + nbOfSM.emplace_back(deviceProp.multiProcessorCount); } - return std::make_tuple(gpuNames, sharedMemSizes); + return std::make_tuple( + gpuNames, + sharedMemSizes, + sharedMemSizesPerSM, + blocksPerSM, + threadsPerSM, + nbOfSM); } } // namespace @@ -61,8 +90,13 @@ CudaGPUInfo& CudaGPUInfo::GPUInfo() { static thread_local bool inited = false; if (!inited) { auto infos = init(); - pInfo = std::unique_ptr( - new CudaGPUInfo(std::get<0>(infos), std::get<1>(infos))); + pInfo = std::unique_ptr(new CudaGPUInfo( + std::get<0>(infos), + std::get<1>(infos), + std::get<2>(infos), + std::get<3>(infos), + std::get<4>(infos), + std::get<5>(infos))); inited = true; } return *pInfo; @@ -102,4 +136,33 @@ size_t CudaGPUInfo::SharedMemorySize() const { } return sharedMemSizes_.at(CurrentGPUId()); } + +size_t CudaGPUInfo::SharedMemorySizePerSM() const { + if (NumberGPUs() == 0) { + return 0; // no shared memory per sm if no GPUs + } + return sharedMemSizesPerSM_.at(CurrentGPUId()); +} + +size_t CudaGPUInfo::BlocksPerSM() const { + if (NumberGPUs() == 0) { + return 0; // no blocks per sm if no GPUs + } + return blocksPerSM_.at(CurrentGPUId()); +} + +size_t CudaGPUInfo::ThreadsPerSM() const { + if (NumberGPUs() == 0) { + return 0; // no threads per sm if no GPUs + } + return threadsPerSM_.at(CurrentGPUId()); +} + +size_t CudaGPUInfo::NbOfSM() const { + if (NumberGPUs() == 0) { + return 0; // no sm if no GPUs + } + return nbOfSM_.at(CurrentGPUId()); +} + } // namespace tc diff --git a/tc/core/cuda/cuda.h b/tc/core/cuda/cuda.h index 70406f6a4..30d9fbe65 100644 --- a/tc/core/cuda/cuda.h +++ b/tc/core/cuda/cuda.h @@ -96,8 +96,17 @@ struct WithCudaDevice { class CudaGPUInfo { CudaGPUInfo( const std::vector& gpuNames, - const std::vector& sharedMemSizes) - : gpuNames_(gpuNames), sharedMemSizes_(sharedMemSizes) {} + const std::vector& sharedMemSizes, + const std::vector& sharedMemSizesPerSM, + const std::vector& blocksPerSM, + const std::vector& threadsPerSM, + const std::vector& nbOfSM) + : gpuNames_(gpuNames), + sharedMemSizes_(sharedMemSizes), + sharedMemSizesPerSM_(sharedMemSizesPerSM), + blocksPerSM_(blocksPerSM), + threadsPerSM_(threadsPerSM), + nbOfSM_(nbOfSM) {} public: static CudaGPUInfo& GPUInfo(); @@ -110,9 +119,17 @@ class CudaGPUInfo { std::string GetGPUName(int id = -1) const; std::string getCudaDeviceStr() const; size_t SharedMemorySize() const; + size_t SharedMemorySizePerSM() const; + size_t BlocksPerSM() const; + size_t ThreadsPerSM() const; + size_t NbOfSM() const; std::vector gpuNames_; std::vector sharedMemSizes_; + std::vector sharedMemSizesPerSM_; + std::vector blocksPerSM_; + std::vector threadsPerSM_; + std::vector nbOfSM_; }; struct CudaProfiler { diff --git a/tc/core/cuda/cuda_backend.h b/tc/core/cuda/cuda_backend.h index f2cfb239a..3faf2dcc9 100644 --- a/tc/core/cuda/cuda_backend.h +++ b/tc/core/cuda/cuda_backend.h @@ -37,6 +37,7 @@ struct CudaCompilationResult { std::vector 
parameters; Grid grid; Block block; + bool useGridSync; }; /** diff --git a/tc/core/cuda/cuda_libraries.h b/tc/core/cuda/cuda_libraries.h index c2fdb1de8..2f745cef6 100644 --- a/tc/core/cuda/cuda_libraries.h +++ b/tc/core/cuda/cuda_libraries.h @@ -61,6 +61,12 @@ __device__ void __syncwarp(unsigned mask = 0xFFFFFFFF) {} #endif )C"; +constexpr auto gridSyncFunctions = R"C( +__device__ void __syncgrid() { + cudaCGSynchronize(cudaCGGetIntrinsicHandle(cudaCGScopeGrid),0); +} +)C"; + constexpr auto mathFunctionDecl = R"C( // BEGIN MATH FUNCTIONS FROM CUDA diff --git a/tc/core/cuda/cuda_rtc.cc b/tc/core/cuda/cuda_rtc.cc index 412e397dc..31407319f 100644 --- a/tc/core/cuda/cuda_rtc.cc +++ b/tc/core/cuda/cuda_rtc.cc @@ -25,6 +25,7 @@ #include "tc/core/cuda/cuda_rtc.h" #include "tc/core/flags.h" #include "tc/core/scope_guard.h" +#include "tc/version/cuda_version.h" namespace tc { std::mutex nvrtc_mutex; @@ -50,7 +51,8 @@ void CudaRTCFunction::clear() { std::unique_ptr CudaRTCFunction::Compile( const std::string& name, - const std::string& source) { + const std::string& source, + bool useGridSync) { std::unique_ptr res(new CudaRTCFunction()); res->specializedName = name; res->cleared_ = false; @@ -88,6 +90,9 @@ std::unique_ptr CudaRTCFunction::Compile( "-DNVRTC_CUB=1", cudaHome.c_str(), cubHome.c_str()}; + if (useGridSync) { + nvrtcts.push_back("--relocatable-device-code=true"); + } if (FLAGS_debug_cuda) { nvrtcts.push_back(nvrtc_debug_opts[0]); nvrtcts.push_back(nvrtc_debug_opts[1]); @@ -132,6 +137,7 @@ std::ostream& operator<<(std::ostream& os, const std::array& a) { Duration CudaRTCFunction::Launch( const std::array& grid, const std::array& block, + bool useGridSync, unsigned int shared_mem, cudaStream_t stream, std::vector params, @@ -143,8 +149,28 @@ Duration CudaRTCFunction::Launch( if (perGpuModule_.count(dev) == 0) { CUmodule module; CUfunction function; - TC_CUDA_DRIVERAPI_ENFORCE( - cuModuleLoadDataEx(&module, nvrtc_ptx.data(), 0, 0, 0)); + if (useGridSync) { + CUlinkState linkState; + TC_CUDA_DRIVERAPI_ENFORCE(cuLinkCreate(0, 0, 0, &linkState)); + TC_CUDA_DRIVERAPI_ENFORCE(cuLinkAddFile( + linkState, CU_JIT_INPUT_LIBRARY, cuda_libdevrt_path, 0, 0, 0)); + TC_CUDA_DRIVERAPI_ENFORCE(cuLinkAddData( + linkState, + CU_JIT_INPUT_PTX, + (void*)nvrtc_ptx.data(), + nvrtc_ptx.size(), + "device_code.ptx", + 0, + 0, + 0)); + size_t cubinSize; + void* cubin; + TC_CUDA_DRIVERAPI_ENFORCE(cuLinkComplete(linkState, &cubin, &cubinSize)); + TC_CUDA_DRIVERAPI_ENFORCE(cuModuleLoadData(&module, cubin)); + } else { + TC_CUDA_DRIVERAPI_ENFORCE( + cuModuleLoadDataEx(&module, nvrtc_ptx.data(), 0, 0, 0)); + } perGpuModule_.emplace(dev, module); TC_CUDA_DRIVERAPI_ENFORCE( cuModuleGetFunction(&function, module, specializedName.c_str())); @@ -174,18 +200,32 @@ Duration CudaRTCFunction::Launch( unsigned int by = block[1]; unsigned int bz = block[2]; auto launch = [&]() { - TC_CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel( - perGpuKernel_.at(dev), - gx, - gy, - gz, - bx, - by, - bz, - shared_mem, - stream, - args_voidp.data(), - 0)); + if (useGridSync) { + TC_CUDA_DRIVERAPI_ENFORCE(cuLaunchCooperativeKernel( + perGpuKernel_.at(dev), + gx, + gy, + gz, + bx, + by, + bz, + shared_mem, + stream, + args_voidp.data())); + } else { + TC_CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel( + perGpuKernel_.at(dev), + gx, + gy, + gz, + bx, + by, + bz, + shared_mem, + stream, + args_voidp.data(), + 0)); + } }; if (not profile) { diff --git a/tc/core/cuda/cuda_rtc.h b/tc/core/cuda/cuda_rtc.h index 4aa5e831f..25f46ee0d 100644 --- a/tc/core/cuda/cuda_rtc.h 
+++ b/tc/core/cuda/cuda_rtc.h @@ -41,14 +41,14 @@ class CudaRTCFunction { public: ~CudaRTCFunction(); - static std::unique_ptr Compile( - const std::string& name, - const std::string& source); + static std::unique_ptr + Compile(const std::string& name, const std::string& source, bool useGridSync); // if profile is set it returns the kernel runtime Duration Launch( const std::array& grid, const std::array& block, + bool useGridSync, unsigned int shared_mem, cudaStream_t stream, // by copy because we take an address to element when calling the kernel diff --git a/tc/core/cuda/cuda_tc_executor.cc b/tc/core/cuda/cuda_tc_executor.cc index cf27d5f29..e35d1b867 100644 --- a/tc/core/cuda/cuda_tc_executor.cc +++ b/tc/core/cuda/cuda_tc_executor.cc @@ -57,10 +57,13 @@ CudaTcExecutor::CudaTcExecutor( // force unloading in case we JIT with the same name/input/outputs with // different options. this->clearRuntimeCompiledFunction(); - rtcFun_ = CudaRTCFunction::Compile( - compilationResult.specializedName, compilationResult.source); grid_ = compilationResult.grid; block_ = compilationResult.block; + useGridSync_ = compilationResult.useGridSync; + rtcFun_ = CudaRTCFunction::Compile( + compilationResult.specializedName, + compilationResult.source, + useGridSync_); auto t1 = std::chrono::high_resolution_clock::now(); LOG_IF(INFO, FLAGS_debug_tc_mapper) << "[COMPILE] Compiling with host JIT compiler took: " @@ -100,12 +103,14 @@ CudaCompilationResult CudaBackend::compileWithTcMapper( std::string source; Grid grid; Block block; - std::tie(source, grid, block) = mappedScop->codegen(specializedName); + bool useGridSync; + std::tie(source, grid, block, useGridSync) = + mappedScop->codegen(specializedName); LOG_IF(INFO, FLAGS_dump_cuda) << "generatedCuda: " << source << "\n" << "grid: " << grid << " block: " << block; return CudaCompilationResult{ - source, specializedName, parameters, grid, block}; + source, specializedName, parameters, grid, block, useGridSync}; } void CudaTcExecutor::uncheckedRun( @@ -118,6 +123,7 @@ void CudaTcExecutor::uncheckedRun( rtcFun_->Launch( grid_.view.extractDefaultedArray(), block_.view.extractDefaultedArray(), + useGridSync_, 0, info.stream, parameters_, @@ -136,6 +142,7 @@ ProfilingInfo CudaTcExecutor::profileUnchecked( Duration kernelRuntime(rtcFun_->Launch( grid_.view.extractDefaultedArray(), block_.view.extractDefaultedArray(), + useGridSync_, 0, stream, parameters_, diff --git a/tc/core/cuda/cuda_tc_executor.h b/tc/core/cuda/cuda_tc_executor.h index a497c940d..bb0b7a911 100644 --- a/tc/core/cuda/cuda_tc_executor.h +++ b/tc/core/cuda/cuda_tc_executor.h @@ -63,5 +63,6 @@ class CudaTcExecutor : public TcExecutor { // GPU-specific results of compilation Grid grid_; Block block_; + bool useGridSync_; }; } // namespace tc diff --git a/tc/core/flags.cc b/tc/core/flags.cc index 80c9e5ec5..b12c74ea8 100644 --- a/tc/core/flags.cc +++ b/tc/core/flags.cc @@ -37,6 +37,8 @@ DEFINE_bool( "Print debug spew for the tc_mapper like cuda code, mapping options etc"); DEFINE_bool(dump_cuda, false, "Print the generated source"); DEFINE_bool(dump_ptx, false, "Dump the generated PTX"); +DEFINE_bool(grid_sync, false, "Use the grid sync feature."); +DEFINE_bool(reduce_launch_size, false, "Reduce the launch size."); // CPU codegen options DEFINE_bool(llvm_dump_before_opt, false, "Print IR before optimization"); diff --git a/tc/core/flags.h b/tc/core/flags.h index c748759a4..e49d14b3a 100644 --- a/tc/core/flags.h +++ b/tc/core/flags.h @@ -30,6 +30,8 @@ DECLARE_bool(debug_cuda); DECLARE_bool(debug_tuner); 
DECLARE_bool(dump_cuda); DECLARE_bool(dump_ptx); +DECLARE_bool(grid_sync); +DECLARE_bool(reduce_launch_size); // llvm codegen DECLARE_bool(llvm_dump_before_opt); diff --git a/tc/core/gpu.h b/tc/core/gpu.h index cdacc05a3..1e9328f9f 100644 --- a/tc/core/gpu.h +++ b/tc/core/gpu.h @@ -34,4 +34,51 @@ inline size_t querySharedMemorySize() { #endif } +/// Get the shared memory size per sm of the GPU device active in the current +/// thread. +/// The call is forwarded to the appropriate GPU driver (CUDA in particular). +/// If a thread has no associated GPU device, return 0. +inline size_t querySharedMemorySizePerSM() { +#ifdef CUDA_HOME + return CudaGPUInfo::GPUInfo().SharedMemorySizePerSM(); +#else + return 0; +#endif +} + +/// Get the maximum number of blocks per sm of the GPU device active +/// in the current thread. +/// The call is forwarded to the appropriate GPU driver (CUDA in particular). +/// If a thread has no associated GPU device, return 0. +inline size_t queryBlocksPerSM() { +#ifdef CUDA_HOME + return CudaGPUInfo::GPUInfo().BlocksPerSM(); +#else + return 0; +#endif +} + +/// Get the maximum number of threads per sm of the GPU device active +/// in the current thread. +/// The call is forwarded to the appropriate GPU driver (CUDA in particular). +/// If a thread has no associated GPU device, return 0. +inline size_t queryThreadsPerSM() { +#ifdef CUDA_HOME + return CudaGPUInfo::GPUInfo().ThreadsPerSM(); +#else + return 0; +#endif +} + +/// Get the number of sm on the GPU device active in the current thread. +/// The call is forwarded to the appropriate GPU driver (CUDA in particular). +/// If a thread has no associated GPU device, return 0. +inline size_t queryNbOfSM() { +#ifdef CUDA_HOME + return CudaGPUInfo::GPUInfo().NbOfSM(); +#else + return 0; +#endif +} + } // namespace tc diff --git a/tc/core/polyhedral/cuda/codegen.cc b/tc/core/polyhedral/cuda/codegen.cc index 032d5dcab..60eccbc06 100644 --- a/tc/core/polyhedral/cuda/codegen.cc +++ b/tc/core/polyhedral/cuda/codegen.cc @@ -450,6 +450,8 @@ void AstPrinter::emitStmt(isl::ast_node_user node) { context_.ss << "__syncthreads();" << std::endl; } else if (context_.scop().isWarpSyncId(stmtId)) { context_.ss << "__syncwarp();" << std::endl; + } else if (context_.scop().isGridSyncId(stmtId)) { + context_.ss << "__syncgrid();" << std::endl; } else if ( stmtId.get_name() == kReadIdName || stmtId.get_name() == kWriteIdName) { emitCopyStmt(statementContext); diff --git a/tc/core/polyhedral/cuda/mapped_scop.cc b/tc/core/polyhedral/cuda/mapped_scop.cc index 6be0812fe..53102529b 100644 --- a/tc/core/polyhedral/cuda/mapped_scop.cc +++ b/tc/core/polyhedral/cuda/mapped_scop.cc @@ -882,6 +882,7 @@ std::unique_ptr makeSpecializedMappedScop( std::move(scop), grid, block, + mappedScop.useGridSync, mappedScop.unroll, mappedScop.useReadOnlyCache); res->insertMappingContext(); @@ -898,7 +899,7 @@ std::unique_ptr makeSpecializedMappedScop( // Before generating code, make a copy of the scop and insert // the context of the original scop as top-level // context node in schedule tree. 
-std::tuple MappedScop::codegen( +std::tuple MappedScop::codegen( const std::string& specializedName) const { validate(schedule()); @@ -906,7 +907,7 @@ std::tuple MappedScop::codegen( std::stringstream code; code << code::cpp::boundsAsTemplate << code::c::types << code::c::defines; - code << code::c::warpSyncFunctions; + code << code::c::warpSyncFunctions << code::c::gridSyncFunctions; code << std::endl; if (useReadOnlyCache) { code << code::cuda::ldg; @@ -922,7 +923,8 @@ std::tuple MappedScop::codegen( return std::make_tuple( code.str(), mappedScopForCodegen->numBlocks, - mappedScopForCodegen->numThreads); + mappedScopForCodegen->numThreads, + mappedScopForCodegen->useGridSync); } // Split out a single reduction tile (in the directions other than @@ -953,16 +955,101 @@ detail::ScheduleTree* MappedScop::splitOutReductionTileAndInsertSyncs( return child; } +namespace { +// Insert grid synchronizations where needed in st. +void insertGridSyncsDFS( + Scop& scop, + const std::vector& outerBands, + detail::ScheduleTree* st) { + using namespace polyhedral::detail; + // If the root of the schedule tree is an outermost coincident band, there is + // no synchronization left. + if (std::find(outerBands.begin(), outerBands.end(), st) != outerBands.end()) { + return; + } + auto nChildren = st->numChildren(); + auto children = st->children(); + for (size_t i = 0; i < nChildren; ++i) { + insertGridSyncsDFS(scop, outerBands, children[i]); + } + + // Insert synchronizations in sequences. + if (st->elemAs()) { + CHECK(nChildren); + if (hasOuterSequentialMember(scop.scheduleRoot(), st)) { + scop.insertSync(st, nChildren, Scop::SyncLevel::Grid); + } + for (size_t i = nChildren - 1; i > 0; --i) { + scop.insertSync(st, i, Scop::SyncLevel::Grid); + } + } + + // Insert synchronizations after sequential loops. + if (st->elemAs()) { + scop.insertSyncAfter(st, Scop::SyncLevel::Grid); + } +} +} // namespace + +void MappedScop::insertGridSyncs( + const std::vector& outerBands) { + using namespace polyhedral::detail; + + insertGridSyncsDFS(*scop_, outerBands, scop_->scheduleRoot()); +} + std::unique_ptr MappedScop::makeWithOuterBlockInnerThreadStrategy( std::unique_ptr&& scopUPtr, const CudaMappingOptions& cudaOptions) { using namespace polyhedral::detail; + // Query the relevant cuda device information + auto nbBlocks = 1; + for (size_t i = 0; i < cudaOptions.grid.size(); i++) { + nbBlocks *= cudaOptions.grid[i]; + } + auto nbThreads = 1; + for (size_t i = 0; i < cudaOptions.block.size(); i++) { + nbThreads *= cudaOptions.block[i]; + } + size_t sharedMemorySize = cudaOptions.proto().has_max_shared_memory() + ? cudaOptions.proto().max_shared_memory() + : querySharedMemorySize(); + size_t maxBlocksPerSM = cudaOptions.proto().has_max_blocks_per_sm() + ? cudaOptions.proto().max_blocks_per_sm() + : queryBlocksPerSM(); + size_t maxThreadsPerSM = cudaOptions.proto().has_max_threads_per_sm() + ? cudaOptions.proto().max_threads_per_sm() + : queryThreadsPerSM(); + size_t sharedMemorySizePerSM = + cudaOptions.proto().has_max_shared_memory_per_sm() + ? cudaOptions.proto().max_shared_memory_per_sm() + : querySharedMemorySizePerSM(); + size_t nbOfSM = cudaOptions.proto().has_nb_of_sm() + ? cudaOptions.proto().nb_of_sm() + : queryNbOfSM(); + auto blocksPerSM = nbOfSM == 0 ? 
0 : ((nbBlocks + nbOfSM - 1) / nbOfSM); + auto threadsPerSM = nbThreads * blocksPerSM; + + bool useGridSync = FLAGS_grid_sync; + if (useGridSync) { + useGridSync &= nbOfSM * maxBlocksPerSM >= nbBlocks; + useGridSync &= maxThreadsPerSM > threadsPerSM; + } + + if (useGridSync) { + LOG(WARNING) << "Use grid sync" << std::endl; + } + if (FLAGS_grid_sync && !useGridSync) { + LOG(WARNING) << "Can't use grid sync" << std::endl; + } + const auto& generic = cudaOptions.generic; auto mappedScop = std::unique_ptr(new MappedScop( std::move(scopUPtr), ::tc::Grid(cudaOptions.grid), ::tc::Block(cudaOptions.block), + useGridSync, generic.proto.unroll(), cudaOptions.proto().use_readonly_cache())); auto& scop = mappedScop->scop_; @@ -975,20 +1062,32 @@ std::unique_ptr MappedScop::makeWithOuterBlockInnerThreadStrategy( // 2. Schedule scop = Scop::makeScheduled(*scop, generic.outerScheduleOptions); - // 3. Tile + // 3. Find and Tile outermost coincident bands... TC_CHECK_LT(0u, generic.tiling.size()) << "Must pass tile vector with >= 1 tile sizes"; - auto outerBand = scop->tileOuterBand(generic.tiling); + std::vector tiledBands; + if (useGridSync) { + tiledBands = scop->tileOuterCoincidentBands(generic.tiling); + sharedMemorySize = std::min( + sharedMemorySize, + blocksPerSM == 0 ? 0 : sharedMemorySizePerSM / blocksPerSM); + } else { + tiledBands = {scop->tileOuterBand(generic.tiling)}; + } // 4. Optionally reschedule if point loops need a different strategy than // tile loops - if (generic.outerScheduleOptions != generic.intraTileScheduleOptions) { - scop->reschedule(outerBand->child({0}), generic.intraTileScheduleOptions); - LOG_IF(INFO, FLAGS_debug_tc_mapper) - << "After intra-tile rescheduling:" << std::endl - << *mappedScop->schedule(); + for (auto outerBand : tiledBands) { + if (generic.outerScheduleOptions != generic.intraTileScheduleOptions && + outerBand->numChildren() != 0) { + scop->reschedule(outerBand->child({0}), generic.intraTileScheduleOptions); + } } + LOG_IF(INFO, FLAGS_debug_tc_mapper) + << "After intra-tile rescheduling:" << std::endl + << *mappedScop->schedule(); + // 1b. ...or after rescheduling if (!generic.proto.fix_parameters_before_scheduling()) { scop->specializeToContext(); @@ -998,34 +1097,44 @@ std::unique_ptr MappedScop::makeWithOuterBlockInnerThreadStrategy( mappedScop->insertMappingContext(); // 6. Map to threads - if (outerBand->numChildren() > 0) { - TC_CHECK_EQ(1u, outerBand->numChildren()); - // 6.1. Optionally detect reductions while mapping to threads + for (auto outerBand : tiledBands) { + if (outerBand->numChildren() > 0) { + TC_CHECK_EQ(1u, outerBand->numChildren()); - if (generic.proto.match_library_calls()) { - mappedScop->detectReductions(outerBand->child({0})); + // 6.1. Optionally detect reductions while mapping to threads + if (generic.proto.match_library_calls()) { + mappedScop->detectReductions(outerBand->child({0})); + } + auto child = outerBand->child({0}); + size_t numMappedInnerThreads = + mappedScop->mapInnermostBandsToThreads(child); + fixThreadsBelow(*mappedScop, outerBand, numMappedInnerThreads); } - auto child = outerBand->child({0}); - size_t numMappedInnerThreads = - mappedScop->mapInnermostBandsToThreads(child); - fixThreadsBelow(*mappedScop, outerBand, numMappedInnerThreads); - LOG_IF(INFO, FLAGS_debug_tc_mapper) - << "After mapping to threads:" << std::endl - << *mappedScop->schedule(); } + LOG_IF(INFO, FLAGS_debug_tc_mapper) + << "After mapping to threads:" << std::endl + << *mappedScop->schedule(); + // 7. 
Map to blocks - mappedScop->mapToBlocksAndScaleBand( - outerBand, generic.tiling.extractVector()); + for (auto outerBand : tiledBands) { + mappedScop->mapToBlocksAndScaleBand( + outerBand, generic.tiling.extractVector()); + } + LOG_IF(INFO, FLAGS_debug_tc_mapper) << "After mapping to blocks:" << std::endl << *mappedScop->schedule(); - // 8. Promote to shared memory below the loops mapped to blocks. - // This may split the outer band, so find the new outer band after promotion. + // 8. Insert grid synchronization where needed. + mappedScop->insertGridSyncs(tiledBands); + LOG_IF(INFO, FLAGS_debug_tc_mapper) + << "After inserting grid synchronization:" << std::endl + << *mappedScop->schedule(); + + // 9. Promote to shared memory below the loops mapped to blocks. + // This may split the outer band, so find the new outer band after + // promotion. if (cudaOptions.proto().use_shared_memory()) { - size_t sharedMemorySize = cudaOptions.proto().has_max_shared_memory() - ? cudaOptions.proto().max_shared_memory() - : querySharedMemorySize(); // If reductions found, their synchronization requires an opaque cache in // shared memory. Subtract 4k from available shared memory for each // reduction found, this is hack based on each thread of max 1024 in the @@ -1040,33 +1149,50 @@ std::unique_ptr MappedScop::makeWithOuterBlockInnerThreadStrategy( sharedMemorySize -= reductionMemoryRequirement; } - auto band = outerBand->elemAs(); - LOG_IF(WARNING, FLAGS_debug_tc_mapper && band->nMember() == 0) - << "Aborting memory promotion because outer band has 0 members (NYI)"; - if (band->nMember() > 0 && sharedMemorySize > 0) { + if (sharedMemorySize > 0) { LOG_IF( WARNING, cudaOptions.proto().unroll_copy_shared() && !generic.proto.has_unroll()) << "requested to unroll copies to shared memory without providing the unroll size"; + bool unroll = cudaOptions.proto().unroll_copy_shared() && + generic.proto.has_unroll(); + + std::vector depths; + std::vector bandsWithPromotion; + for (auto band : tiledBands) { + auto bandElem = band->elemAs(); + LOG_IF(WARNING, FLAGS_debug_tc_mapper && bandElem->nMember() == 0) + << "Aborting memory promotion for one band because it has 0 members (NYI)"; + if (bandElem->nMember() == 0) { + continue; + } + bandsWithPromotion.push_back(band); + auto depthBefore = band->scheduleDepth(scop->scheduleRoot()); + depths.push_back( + depthBefore + + std::min( + bandElem->nOuterCoincident(), + mappedScop->numBlocks.view.size())); + } - promoteGreedilyAtDepth( + sharedMemorySize = promoteGreedilyAtDepth( *mappedScop, - std::min(band->nOuterCoincident(), mappedScop->numBlocks.view.size()), + bandsWithPromotion, + depths, sharedMemorySize, cudaOptions.proto().unroll_copy_shared() && generic.proto.has_unroll()); - auto bands = ScheduleTree::collectDFSPreorder( - scop->scheduleRoot(), ScheduleTreeType::Band); - if (bands.size() == 0) { // Sanity check. + /*auto bands = ScheduleTree::collectDFSPreorder( + scop->scheduleRoot(), ScheduleTreeType::Band); + if (bands.size() == 0) { // Sanity check. throw NoBandsException("no bands after promotion"); - } - outerBand = bands[0]; + }*/ } } - // 9. Promote to registers below the loops mapped to threads. + // 10. Promote to registers below the loops mapped to threads. 
if (cudaOptions.proto().use_private_memory()) { promoteToRegistersBelowThreads(*mappedScop, -1ull); } diff --git a/tc/core/polyhedral/cuda/mapped_scop.h b/tc/core/polyhedral/cuda/mapped_scop.h index 55490596a..c8e845173 100644 --- a/tc/core/polyhedral/cuda/mapped_scop.h +++ b/tc/core/polyhedral/cuda/mapped_scop.h @@ -61,11 +61,13 @@ class MappedScop { std::unique_ptr&& scop, ::tc::Grid grid, ::tc::Block block, + bool useGridSync_, uint64_t unroll_, bool useReadOnlyCache_) : scop_(std::move(scop)), numBlocks(grid), numThreads(block), + useGridSync(useGridSync_), unroll(unroll_), useReadOnlyCache(useReadOnlyCache_) {} @@ -73,7 +75,12 @@ class MappedScop { static inline std::unique_ptr makeOneBlockOneThread( std::unique_ptr&& scop) { auto mscop = std::unique_ptr(new MappedScop( - std::move(scop), ::tc::Grid{1, 1, 1}, ::tc::Block{1, 1, 1}, 1, false)); + std::move(scop), + ::tc::Grid{1, 1, 1}, + ::tc::Block{1, 1, 1}, + false, + 1, + false)); auto band = mscop->scop_->obtainOuterBand(); mscop->mapBlocksForward(band, 0); mscop->mapThreadsBackward(band); @@ -86,10 +93,11 @@ class MappedScop { std::unique_ptr&& scop, ::tc::Grid grid, ::tc::Block block, + bool useGridSync, uint64_t unroll, bool useReadOnlyCache) { - return std::unique_ptr( - new MappedScop(std::move(scop), grid, block, unroll, useReadOnlyCache)); + return std::unique_ptr(new MappedScop( + std::move(scop), grid, block, useGridSync, unroll, useReadOnlyCache)); } // Apply the hand-written OuterBlockInnerThread mapping strategy. @@ -121,7 +129,7 @@ class MappedScop { // Generate CUDA code at the current state of transformation provided a // name for the generated function. - std::tuple codegen( + std::tuple codegen( const std::string& specializedName) const; // Accessors.. @@ -215,6 +223,10 @@ class MappedScop { // Return a pointer to the split off tile. detail::ScheduleTree* splitOutReductionTileAndInsertSyncs( detail::ScheduleTree* band); + // Insert grid synchronization where needed. + // They are inserted in every sequence and below every sequential loop + // which are above the outer coincident bands. + void insertGridSyncs(const std::vector& outerBands); // Map "band" to thread identifiers using as many blockSizes values as outer // coincident dimensions (plus reduction dimension, if any), // insert synchronization in case of a reduction, and @@ -233,6 +245,7 @@ class MappedScop { public: const ::tc::Grid numBlocks; const ::tc::Block numThreads; + bool useGridSync; const uint64_t unroll; const bool useReadOnlyCache; diff --git a/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc b/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc index 4804fdb04..8f6f35e65 100644 --- a/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc +++ b/tc/core/polyhedral/cuda/memory_promotion_heuristic.cc @@ -29,6 +29,7 @@ #include #include #include +#include namespace tc { namespace polyhedral { @@ -476,6 +477,13 @@ std::vector bandsSplitAfterDepth( return functional::Map(splitAtDepth, bands); } +struct IslIdSizeTHash { + size_t operator()(const std::pair& k) const { + return (isl::IslIdIslHash()(k.first) ^ (k.second << 1)); + } +}; +} // namespace + /* * For every place in the schedule tree where schedule depth (i.e., the number * of preceding band members) is "depth", promote tensor reference groups to @@ -487,15 +495,18 @@ std::vector bandsSplitAfterDepth( * Only promote if the tensor elements referenced by the group are reused or * accessed in a non-coalesced way. 
*/ -void promoteToSharedGreedy( +size_t promoteToSharedGreedy( Scop& scop, const Block& block, - size_t depth, + std::vector trees, + std::vector depths, size_t maxMemory) { using namespace tc::polyhedral::detail; - if (depth == 0) { - throw promotion::PromotionNYI("promotion before any band"); + for (auto depth : depths) { + if (depth == 0) { + throw promotion::PromotionNYI("promotion before any band"); + } } auto root = scop.scheduleRoot(); @@ -503,8 +514,12 @@ void promoteToSharedGreedy( // 1. Collect all bands with a member located at the given depth in the // overall schedule. Make sure this is the last member of the band by // splitting off the subsequent members into a different band. - auto bands = bandsContainingScheduleDepth(root, depth); - bands = bandsSplitAfterDepth(bands, root, depth); + std::vector bands; + for (size_t i = 0; i < trees.size(); ++i) { + auto treeBands = bandsContainingScheduleDepth(trees[i], depths[i]); + treeBands = bandsSplitAfterDepth(treeBands, root, depths[i]); + bands.insert(bands.end(), treeBands.begin(), treeBands.end()); + } // 2. Compute full schedule without mapping filters. The filters would make // it impossible to test for coalescing by incrementing a member of a band as @@ -517,103 +532,127 @@ void promoteToSharedGreedy( // group either features reuse or is accessed in a non-coalesced way, or // both. size_t remainingMemory = maxMemory; - for (auto bandNode : bands) { - auto activePoints = activeDomainPoints(root, bandNode); - auto partialSched = partialSchedule(root, bandNode); - - auto groupMap = TensorReferenceGroup::accessedWithin( - partialSched.intersect_domain(activePoints), scop.reads, scop.writes); - // Pure affine schedule without (mapping) filters. - auto partialSchedMupa = partialScheduleMupa(root, bandNode); - - // Prepare groups for sorting, to have specified order necessary for - // reproducibility and tests. - using TensorGroupList = std::pair; - std::vector groupLists( - std::make_move_iterator(groupMap.begin()), - std::make_move_iterator(groupMap.end())); - - // Computes the total number of references in all groups. - auto refsCount = [](const TensorGroupsInfo& info) { - size_t refs = 0; - for (auto const& group : info) { - refs += group->referenceIds().size(); - } - return refs; - }; + std::vector partialScheds; + // Pure affine schedule without (mapping) filters. + std::vector partialSchedsMupa; + std::unordered_map< + std::pair, + TensorGroupsInfo, + IslIdSizeTHash> + groupMap; + + for (size_t i = 0; i < bands.size(); ++i) { + auto activePoints = activeDomainPoints(root, bands[i]); + auto groupMapOfSubtree = TensorReferenceGroup::accessedWithin(partialSched.intersect_domain(activePoints), scop.reads, scop.writes); + partialScheds.push_back(partialSchedule(root, bands[i])); + partialSchedsMupa.push_back(partialScheduleMupa(root, bands[i])); + + for (auto& tensorGroup : groupMapOfSubtree) { + groupMap[std::make_pair(tensorGroup.first, i)] = + std::move(tensorGroup.second); + } + } - // Sort by the total number of references, then by name. Because names are - // guarenteed to be unique, the order is total. + // Prepare groups for sorting, to have specified order necessary for + // reproducibility and tests. + using TensorGroupList = + std::pair, TensorGroupsInfo>; + std::vector groupLists( + std::make_move_iterator(groupMap.begin()), + std::make_move_iterator(groupMap.end())); + + // Computes the total number of references in all groups. 
+ auto refsCount = [](const TensorGroupsInfo& info) { + size_t refs = 0; + for (auto const& group : info) { + refs += group->referenceIds().size(); + } + return refs; + }; + + // Sort by the total number of references, then by name. Because names are + // guarenteed to be unique, the order is total. + std::sort( + groupLists.begin(), + groupLists.end(), + [refsCount](const TensorGroupList& l1, const TensorGroupList& l2) { + auto r1 = refsCount(l1.second); + auto r2 = refsCount(l2.second); + auto n1 = l1.first.first.get_name(); + auto n2 = l2.first.first.get_name(); + return std::tie(r1, n1, l1.first.second) < + std::tie(r2, n2, l2.first.second); + }); + + for (auto& tensorGroups : groupLists) { + auto tensorId = tensorGroups.first.first; + auto bandNodeId = tensorGroups.first.second; + // Sort the reference groups to prioritize groups with more references as + // they are more likely to benefit from promotion. std::sort( - groupLists.begin(), - groupLists.end(), - [refsCount](const TensorGroupList& l1, const TensorGroupList& l2) { - auto r1 = refsCount(l1.second); - auto r2 = refsCount(l2.second); - return r1 == r2 ? l1.first.get_name() < l2.first.get_name() : r1 < r2; + tensorGroups.second.begin(), + tensorGroups.second.end(), + [refsCount]( + const std::unique_ptr& group1, + const std::unique_ptr& group2) { + return group1->referenceIds().size() > group2->referenceIds().size(); }); - for (auto& tensorGroups : groupLists) { - auto tensorId = tensorGroups.first; - // Sort the reference groups to prioritize groups with more references as - // they are more likely to benefit from promotion. - std::sort( - tensorGroups.second.begin(), - tensorGroups.second.end(), - [refsCount]( - const std::unique_ptr& group1, - const std::unique_ptr& group2) { - return group1->referenceIds().size() > - group2->referenceIds().size(); - }); - - for (auto& group : tensorGroups.second) { - auto sizes = group->approximationSizes(); - if (sizes.size() == 0) { - throw promotion::PromotionLogicError("cannot promote a scalar"); - } - if (sizes.back() % 2 == 0) { - sizes.back() += 1; - } - auto nApproximationElements = std::accumulate( - sizes.begin(), sizes.end(), 1, std::multiplies()); - size_t memoryRequirement = - nApproximationElements * scop.findArgument(tensorId).type().bytes(); - if (memoryRequirement > remainingMemory) { - continue; - } - // Do not promote if the group features no reuse and is accessed in a - // coalesced way. - if (!hasReuseWithin(*group, partialSchedMupa) && - !promotionImprovesCoalescing(root, bandNode, *group, fullSched)) { - continue; - } - scop.promoteGroup( - Scop::PromotedDecl::Kind::SharedMem, - tensorId, - std::move(group), - bandNode, - partialSched, - true); - remainingMemory -= memoryRequirement; + for (auto& group : tensorGroups.second) { + auto sizes = group->approximationSizes(); + if (sizes.size() == 0) { + throw promotion::PromotionLogicError("cannot promote a scalar"); + } + if (sizes.back() % 2 == 0) { + sizes.back() += 1; + } + auto nApproximationElements = std::accumulate( + sizes.begin(), sizes.end(), 1, std::multiplies()); + size_t memoryRequirement = + nApproximationElements * scop.findArgument(tensorId).type().bytes(); + if (memoryRequirement > remainingMemory) { + continue; + } + // Do not promote if the group features no reuse and is accessed in a + // coalesced way. 
+ if (!hasReuseWithin(*group, partialSchedsMupa[bandNodeId]) && + !promotionImprovesCoalescing( + root, bands[bandNodeId], *group, fullSched)) { + continue; } + + scop.promoteGroup( + Scop::PromotedDecl::Kind::SharedMem, + tensorId, + std::move(group), + bands[bandNodeId], + partialScheds[bandNodeId], + true); + remainingMemory -= memoryRequirement; } + } + + for (auto bandNode : bands) { scop.insertSyncsAroundCopies(bandNode); } + return remainingMemory; } } // namespace -void promoteGreedilyAtDepth( +size_t promoteGreedilyAtDepth( MappedScop& mscop, - size_t depth, + std::vector trees, + std::vector depths, size_t sharedMemorySize, bool unrollCopies) { // 1. Promote using heuristic. - promoteToSharedGreedy( - mscop.scop(), mscop.numThreads, depth, sharedMemorySize); + sharedMemorySize = promoteToSharedGreedy( + mscop.scop(), mscop.numThreads, trees, depths, sharedMemorySize); // 2. Map copies to shared, state by copy mapCopiesToThreads(mscop, unrollCopies); + + return sharedMemorySize; } /* diff --git a/tc/core/polyhedral/cuda/memory_promotion_heuristic.h b/tc/core/polyhedral/cuda/memory_promotion_heuristic.h index 508a3d8f6..4fb08fd89 100644 --- a/tc/core/polyhedral/cuda/memory_promotion_heuristic.h +++ b/tc/core/polyhedral/cuda/memory_promotion_heuristic.h @@ -21,6 +21,9 @@ #include "tc/external/isl.h" namespace tc { + +class Block; + namespace polyhedral { class MappedScop; class Scop; @@ -36,9 +39,11 @@ class ScheduleTree; // "threadIdxXScheduleDepthState" contains the schedule depth at which the // computation was mapped to thread x and is used to check whether the global // memory is accessed in a coalesced way. -void promoteGreedilyAtDepth( +// Return the remaining memory. +size_t promoteGreedilyAtDepth( MappedScop& scop, - std::size_t depth, + std::vector trees, + std::vector depths, std::size_t sharedMemorySize, bool unrollCopies); diff --git a/tc/core/polyhedral/schedule_tree.cc b/tc/core/polyhedral/schedule_tree.cc index d087539d9..5610efd26 100644 --- a/tc/core/polyhedral/schedule_tree.cc +++ b/tc/core/polyhedral/schedule_tree.cc @@ -187,6 +187,9 @@ vector ScheduleTree::positionRelativeTo( } size_t ScheduleTree::scheduleDepth(const ScheduleTree* relativeRoot) const { + if (relativeRoot == this) { + return 0; + } size_t depth = 0; for (auto const& anc : ancestors(relativeRoot)) { auto bandElem = anc->elemAs(); diff --git a/tc/core/polyhedral/scop.cc b/tc/core/polyhedral/scop.cc index f7cc8bc76..c7829fee2 100644 --- a/tc/core/polyhedral/scop.cc +++ b/tc/core/polyhedral/scop.cc @@ -208,6 +208,9 @@ void Scop::promoteGroup( void Scop::insertSyncsAroundCopies(ScheduleTree* tree) { // Return immediately if nothing was inserted + if (tree->numChildren() == 0) { + return; + } auto extensionNode = tree->child({0})->elemAs(); if (!extensionNode) { @@ -453,6 +456,93 @@ detail::ScheduleTree* Scop::obtainOuterBand() { return tree; } +namespace { + +// Check if there is at least one band with a coincident dimension +// in the tree. 
+bool hasAtLeastOneCoincidentLoop(detail::ScheduleTree* tree) {
+  if (auto band = tree->elemAs()) {
+    if (find(band->coincident_.begin(), band->coincident_.end(), true) !=
+            band->coincident_.end() &&
+        band->permutable_) {
+      return true;
+    }
+  }
+  auto n = tree->numChildren();
+  if (n == 0) {
+    return false;
+  } else if (n == 1) {
+    return hasAtLeastOneCoincidentLoop(tree->child({0}));
+  } else {
+    for (size_t i = 0; i < n; ++i) {
+      if (hasAtLeastOneCoincidentLoop(tree->child({i}))) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+// Return the outermost coincident bands. These are the bands with at least
+// one coincident dimension, that are permutable, and that have no such
+// coincident band as an ancestor. Some of these bands are created with zero
+// dimensions to ensure that the union of all children of all outermost
+// coincident bands is equal to the leaves. Also, these created bands are
+// put at the highest possible level of the tree, which also minimizes the
+// number of bands inserted.
+
+std::vector obtainOuterCoincidentBands(
+    detail::ScheduleTree* root,
+    detail::ScheduleTree* tree) {
+  auto nChildren = tree->numChildren();
+
+  // If there is no coincident loop, we create an outer band with no dimensions.
+  if (not hasAtLeastOneCoincidentLoop(tree)) {
+    auto domain = root->elemAs();
+    CHECK(domain);
+    auto band = ScheduleTree::makeEmptyBand(root);
+    if (nChildren == 0 || tree->elemAs() ||
+        tree->elemAs()) {
+      return {setPermutable(insertNodeBelow(tree, std::move(band)))};
+    } else {
+      return {setPermutable(insertNodeAbove(root, tree, std::move(band)))};
+    }
+  }
+
+  // If we find a coincident loop directly, there is only one outer band in the tree.
+  while (nChildren == 1) {
+    if (auto band = tree->elemAs()) {
+      if (find(band->coincident_.begin(), band->coincident_.end(), true) !=
+              band->coincident_.end() &&
+          band->permutable_) {
+        return {tree};
+      }
+    }
+    tree = tree->child({0});
+    nChildren = tree->numChildren();
+  }
+
+  // If nChildren is zero, the current tree is a band with a coincident
+  // member. Otherwise, hasAtLeastOneCoincidentLoop would have returned false.
+  if (nChildren == 0) {
+    return {tree};
+  }
+
+  auto children = tree->children();
+  std::vector outerBands = {};
+  // Get the outer bands from the children, and add them to the
+  // list of outer bands.
+ for (size_t i = 0; i < nChildren; ++i) { + auto childOuterBand = obtainOuterCoincidentBands(root, tree->child({i})); + outerBands.insert( + outerBands.end(), childOuterBand.begin(), childOuterBand.end()); + } + return outerBands; +} + +} // namespace + detail::ScheduleTree* Scop::tileOuterBand(const TilingView& tileSizes) { using namespace tc::polyhedral::detail; auto band = obtainOuterBand(); @@ -467,6 +557,24 @@ detail::ScheduleTree* Scop::tileOuterBand(const TilingView& tileSizes) { return res; } +std::vector Scop::tileOuterCoincidentBands( + const TilingView& tileSizes) { + using namespace tc::polyhedral::detail; + auto bands = obtainOuterCoincidentBands(scheduleRoot(), scheduleRoot()); + std::vector tiledBands = {}; + for (auto band : bands) { + std::vector sizes = tileSizes.extractVector(); + auto bandNode = band->elemAs(); + if (bandNode->nMember() < sizes.size()) { + sizes.resize(bandNode->nMember()); + } + tiledBands.push_back(bandTile(band, sizes, TileOptions::ShiftPointLoops)); + } + LOG_IF(INFO, FLAGS_debug_tc_mapper) << "After tiling outer:" << std::endl + << *scheduleTreeUPtr; + return tiledBands; +} + void Scop::reschedule( ScheduleTree* tree, const SchedulerOptionsView& schedulerOptions) { @@ -483,8 +591,6 @@ void Scop::reschedule( auto newTree = computeSchedule(constraints, schedulerOptions); parentTree->detachChild(treePos); parentTree->insertChildren(treePos, newTree->detachChildren()); - LOG_IF(INFO, FLAGS_debug_tc_mapper) << "After rescheduling:" << std::endl - << *scheduleTreeUPtr; } const Halide::OutputImageParam& Scop::findArgument(isl::id id) const { diff --git a/tc/core/polyhedral/scop.h b/tc/core/polyhedral/scop.h index b305315d8..fd6f10e81 100644 --- a/tc/core/polyhedral/scop.h +++ b/tc/core/polyhedral/scop.h @@ -195,7 +195,7 @@ struct Scop { isl::id updateId); // The different level of synchronization. - enum class SyncLevel : int { None = 0, Warp = 1, Block = 2 }; + enum class SyncLevel : int { None = 0, Warp = 1, Block = 2, Grid = 3 }; // Given a sequence node in the schedule tree, insert // synchronization before the child at position "pos". @@ -228,6 +228,10 @@ struct Scop { static size_t count = 0; return count++; } + size_t gridSyncUID() const { + static size_t count = 0; + return count++; + } // Make the synchronization id corresponding to the synchronization level. // The level should not be None. @@ -239,6 +243,9 @@ struct Scop { case SyncLevel::Block: return makeSyncId(); break; + case SyncLevel::Grid: + return makeGridSyncId(); + break; default: TC_CHECK(level != SyncLevel::None); return isl::id(); @@ -256,6 +263,12 @@ struct Scop { ctx, std::string(kWarpSyncIdPrefix) + std::to_string(warpSyncUID())); } + isl::id makeGridSyncId() const { + auto ctx = domain().get_ctx(); + return isl::id( + ctx, std::string(kGridSyncIdPrefix) + std::to_string(gridSyncUID())); + } + // Check if the id has a name with the expected prefix, followed by a long // integer. static bool isIdWithExpectedPrefix( @@ -282,6 +295,10 @@ struct Scop { return isIdWithExpectedPrefix(id, kWarpSyncIdPrefix); } + static bool isGridSyncId(isl::id id) { + return isIdWithExpectedPrefix(id, kGridSyncIdPrefix); + } + static isl::id makeRefId(isl::ctx ctx) { static thread_local size_t count = 0; return isl::id(ctx, std::string("__tc_ref_") + std::to_string(count++)); @@ -404,6 +421,13 @@ struct Scop { // tile loop band. detail::ScheduleTree* tileOuterBand(const TilingView& tiling); + // Tile the outermost coincident bands. 
See the obtainOuterCoincidentBands
+  // function to see what these bands correspond to.
+  // Splits each band into a tile loop band and a point loop band, where the
+  // point loops have fixed trip counts specified in "tiling", and returns
+  // pointers to the tile loop bands.
+  std::vector tileOuterCoincidentBands(
+      const TilingView& tiling);
   // Reschedule the schedule subtree rooted at "tree" with the
   // given scheduler options.
   void reschedule(
diff --git a/tc/proto/mapping_options.proto b/tc/proto/mapping_options.proto
index ff29e3557..c413f1b8c 100644
--- a/tc/proto/mapping_options.proto
+++ b/tc/proto/mapping_options.proto
@@ -70,6 +70,21 @@ message CudaMappingOptionsProto {
   optional uint64 max_shared_memory = 7;
   // Use the readonly cache (i.e. emit __ldg loads)
   required bool use_readonly_cache = 8;
+  // Maximum number of blocks per streaming multiprocessor to use. If not
+  // provided, the number of blocks per sm on the current active device will
+  // be used.
+  optional uint64 max_blocks_per_sm = 9;
+  // Maximum number of threads per streaming multiprocessor to use. If not
+  // provided, the number of threads per sm on the current active device will
+  // be used.
+  optional uint64 max_threads_per_sm = 10;
+  // Maximum amount of shared memory per streaming multiprocessor to use, in
+  // bytes. If not provided, all shared memory available on a sm on the current
+  // active device will be used.
+  optional uint64 max_shared_memory_per_sm = 11;
+  // The number of multiprocessors that can be used. If not provided, all
+  // multiprocessors will be used.
+  optional uint64 nb_of_sm = 12;
 }
 
 message CpuMappingOptionsProto {
diff --git a/tc/version/CMakeLists.txt b/tc/version/CMakeLists.txt
index cbf3ccfe2..31b400618 100644
--- a/tc/version/CMakeLists.txt
+++ b/tc/version/CMakeLists.txt
@@ -2,8 +2,18 @@ get_git_head_revision(GIT_REFSPEC GIT_HASH)
 execute_process(COMMAND ${GIT_EXECUTABLE} describe --all --long HEAD OUTPUT_VARIABLE GIT_TAG)
 SET(GIT_DESCRIPTION "${GIT_REFSPEC}${GIT_HASH}")
 configure_file(version.cc.in version.cc)
+include_directories(${CMAKE_CURRENT_SOURCE_DIR})
 add_library(tc_version STATIC ${CMAKE_CURRENT_BINARY_DIR}/version.cc)
 SET_TARGET_PROPERTIES(tc_version PROPERTIES LINKER_LANGUAGE CXX)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR})
 target_include_directories(tc_version INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})
+
+if (WITH_CUDA)
+  find_library(CUDA_LIBDEVRT_PATH libcudadevrt.a
+    HINTS ${CUDA_TOOLKIT_ROOT_DIR}
+    PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib targets/x86_64-linux/lib/stubs)
+  configure_file(cuda_version.cc.in cuda_version.cc)
+  add_library(tc_cuda_version STATIC ${CMAKE_CURRENT_BINARY_DIR}/cuda_version.cc)
+  SET_TARGET_PROPERTIES(tc_cuda_version PROPERTIES LINKER_LANGUAGE CXX)
+  target_include_directories(tc_cuda_version INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})
+endif()
diff --git a/tc/version/cuda_version.cc.in b/tc/version/cuda_version.cc.in
new file mode 100644
index 000000000..c1915efb3
--- /dev/null
+++ b/tc/version/cuda_version.cc.in
@@ -0,0 +1,5 @@
+#include "cuda_version.h"
+
+namespace tc {
+const char* cuda_libdevrt_path = "@CUDA_LIBDEVRT_PATH@";
+}
diff --git a/tc/version/cuda_version.h b/tc/version/cuda_version.h
new file mode 100644
index 000000000..00da80821
--- /dev/null
+++ b/tc/version/cuda_version.h
@@ -0,0 +1,19 @@
+/**
+ * Copyright (c) 2017-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +namespace tc { +extern const char* cuda_libdevrt_path; +} diff --git a/test/caffe2/cuda/test_caffe2.cc b/test/caffe2/cuda/test_caffe2.cc index ec72b5db6..d33f1c332 100644 --- a/test/caffe2/cuda/test_caffe2.cc +++ b/test/caffe2/cuda/test_caffe2.cc @@ -499,6 +499,142 @@ def fun(float(B, N, M) X, float(B, M, K) Y) -> (Z) CheckEqual(w_ref, w_test, "Z", 1e-6); } +TEST_F(Caffe2Test, TcBatchNormNoGrid) { + static constexpr uint32_t N = 32, C = 4, H = 256, W = 256; + auto init_ws = [&](Workspace& w) { + auto AddInput = AddDeterministicallyRandomInput; + AddInput(w, {N, C, H, W}, "I"); + AddInput(w, {C}, "rMeanIn"); + AddInput(w, {C}, "rVarIn"); + }; + + auto tc = R"TC( + def batchnormnogrid(float(N,C,H,W) I, float(C) rMeanIn, float(C) rVarIn) -> (O, rMeanOut, rVarOut, mean, centered, variance, expectedVariance, normalizedOut) { + mean(c) +=! I(nn, c, hh, ww) + mean(c) = mean(c) / (N * H * W) + rMeanOut(c) = (1 - 0.2) * rMeanIn(c) + 0.2 * mean(c) + centered(n, c, h, w) = I(n, c, h, w) - rMeanOut(c) + variance(n, c, h, w) = centered(n, c, h, w) * centered(n, c, h, w) + expectedVariance(c) +=! (variance(n, c, h, w) + 0.001) / (N * H * W) + rVarOut(c) = rsqrt((1 - 0.2) * rVarIn(c) + 0.2 * expectedVariance(c)) + O(n, c, h, w) = centered(n, c, h, w) / rVarOut(c) + normalizedOut(n, c, h, w) = O(n, c, h, w) + } +)TC"; + + Workspace w_test; + init_ws(w_test); + Argument tcArg = MakeArgument("tc_def", tc); + Argument tcNameArg = MakeArgument("tc_name", "batchnormnogrid"); + CudaMappingOptions options = + tc::makeBaseCliStrategy() + .outerScheduleFusionStrategy(tc::FusionStrategy::Max) + .outerScheduleAllowSkewing(false) + .outerSchedulePositiveOrthant(true) + .intraTileScheduleFusionStrategy(tc::FusionStrategy::Min) + .intraTileScheduleAllowSkewing(false) + .intraTileSchedulePositiveOrthant(true) + .tile(2, 1, 8, 1) + .unroll(1) + .tileImperfectlyNested(false) + .matchLibraryCalls(true) + .mapToThreads(256, 2) + .mapToBlocks(256) + .useSharedMemory(true) + .usePrivateMemory(false) + .unrollCopyShared(true); + Argument strategyArg = MakeArgument( + "mappingOptions", + tc::makeCliStrategy(options).toProtobufSerializedString()); + auto op_def = MakeOperatorDef( + "TcOp", + {"I", "rMeanIn", "rVarIn"}, + {"O", + "rMeanOut", + "rVarOut", + "mean", + "centered", + "variance", + "expectedVariance", + "normalizedOut"}, + {tcArg, tcNameArg, strategyArg}); + auto op = CreateOperator(op_def, &w_test); + ASSERT_TRUE(op.get()); + ASSERT_TRUE(op->Run()); + + TC_CUDA_RUNTIMEAPI_ENFORCE(cudaDeviceSynchronize()); + + // CheckEqual(w_ref, w_test, "Z", 1e-6); +} + +TEST_F(Caffe2Test, TcBatchNormGrid) { + static constexpr uint32_t N = 32, C = 4, H = 256, W = 256; + auto init_ws = [&](Workspace& w) { + auto AddInput = AddDeterministicallyRandomInput; + AddInput(w, {N, C, H, W}, "I"); + AddInput(w, {C}, "rMeanIn"); + AddInput(w, {C}, "rVarIn"); + }; + + auto tc = R"TC( + def batchnormgrid(float(N,C,H,W) I, float(C) rMeanIn, float(C) rVarIn) -> (O, rMeanOut, rVarOut, mean, centered, variance, expectedVariance, normalizedOut) { + mean(c) +=! 
I(nn, c, hh, ww) + mean(c) = mean(c) / (N * H * W) + rMeanOut(c) = (1 - 0.2) * rMeanIn(c) + 0.2 * mean(c) + centered(n, c, h, w) = I(n, c, h, w) - rMeanOut(c) + variance(n, c, h, w) = centered(n, c, h, w) * centered(n, c, h, w) + expectedVariance(c) +=! (variance(n, c, h, w) + 0.001) / (N * H * W) + rVarOut(c) = rsqrt((1 - 0.2) * rVarIn(c) + 0.2 * expectedVariance(c)) + O(n, c, h, w) = centered(n, c, h, w) * rVarOut(c) + normalizedOut(n, c, h, w) = O(n, c, h, w) + } +)TC"; + + Workspace w_test; + init_ws(w_test); + Argument tcArg = MakeArgument("tc_def", tc); + Argument tcNameArg = MakeArgument("tc_name", "batchnormgrid"); + CudaMappingOptions options = + tc::makeBaseCliStrategy() + .outerScheduleFusionStrategy(tc::FusionStrategy::Preserve3Coincident) + .outerScheduleAllowSkewing(false) + .outerSchedulePositiveOrthant(true) + .intraTileScheduleFusionStrategy(tc::FusionStrategy::Min) + .intraTileScheduleAllowSkewing(false) + .intraTileSchedulePositiveOrthant(true) + .tile(2, 64) + .unroll(2) + .tileImperfectlyNested(false) + .matchLibraryCalls(true) + .mapToThreads(32, 2) + .mapToBlocks(128) + .useSharedMemory(false) + .usePrivateMemory(true) + .unrollCopyShared(false); + Argument strategyArg = MakeArgument( + "mappingOptions", + tc::makeCliStrategy(options).toProtobufSerializedString()); + auto op_def = MakeOperatorDef( + "TcOp", + {"I", "rMeanIn", "rVarIn"}, + {"O", + "rMeanOut", + "rVarOut", + "mean", + "centered", + "variance", + "expectedVariance", + "normalizedOut"}, + {tcArg, tcNameArg, strategyArg}); + auto op = CreateOperator(op_def, &w_test); + ASSERT_TRUE(op.get()); + ASSERT_TRUE(op->Run()); + + TC_CUDA_RUNTIMEAPI_ENFORCE(cudaDeviceSynchronize()); + + // CheckEqual(w_ref, w_test, "Z", 1e-6); +} + // TODO: TEST_F(Caffe2Test, DISABLED_TcGather) { auto init_ws = [&](Workspace& w) { diff --git a/test/cuda/CMakeLists.txt b/test/cuda/CMakeLists.txt index 564464388..72f106ffd 100644 --- a/test/cuda/CMakeLists.txt +++ b/test/cuda/CMakeLists.txt @@ -12,7 +12,7 @@ set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) ################################################################################ add_executable(test_basic_gpu test_basic_gpu.cc) add_test(test_basic_gpu test_basic_gpu) -target_link_libraries(test_basic_gpu ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARIES} ${CUDA_NVRTC_LIBRARIES} ${GOOGLE_LIBRARIES} pthread ) +target_link_libraries(test_basic_gpu ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARIES} ${CUDA_NVRTC_LIBRARIES} ${GOOGLE_LIBRARIES} pthread tc_cuda_version) ################################################################################ # Core GPU library tests diff --git a/test/cuda/test_autotuner.cc b/test/cuda/test_autotuner.cc index 71e36a30a..e03f994a2 100644 --- a/test/cuda/test_autotuner.cc +++ b/test/cuda/test_autotuner.cc @@ -178,6 +178,68 @@ def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) { benchmarkKernelOptions(TC, name, inputs, bestOptions); } +TEST_F(ATenCompilationUnitTest, BatchNorm) { + static constexpr uint32_t N = 32, C = 4, H = 256, W = 256; + at::Tensor I = at::CUDA(at::kFloat).rand({N, C, H, W}); + at::Tensor rMeanIn = at::CUDA(at::kFloat).rand({C}); + at::Tensor rVarIn = at::CUDA(at::kFloat).rand({C}); + at::Tensor eps = at::CUDA(at::kFloat).rand({1}); + eps[0] = 1.0f; + at::Tensor momentum = at::CUDA(at::kFloat).rand({1}); + momentum[0] = 1.0; + + std::vector inputs = {momentum, eps, I, rMeanIn, rVarIn}; + + static constexpr auto TC = R"TC( +def batchnorm( + float(1) momentum, float(1) eps, + float(N,C,H,W) I, float(C) rMeanIn, float(C) rVarIn) 
diff --git a/test/cuda/CMakeLists.txt b/test/cuda/CMakeLists.txt
index 564464388..72f106ffd 100644
--- a/test/cuda/CMakeLists.txt
+++ b/test/cuda/CMakeLists.txt
@@ -12,7 +12,7 @@ set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
 ################################################################################
 add_executable(test_basic_gpu test_basic_gpu.cc)
 add_test(test_basic_gpu test_basic_gpu)
-target_link_libraries(test_basic_gpu ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARIES} ${CUDA_NVRTC_LIBRARIES} ${GOOGLE_LIBRARIES} pthread )
+target_link_libraries(test_basic_gpu ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARIES} ${CUDA_NVRTC_LIBRARIES} ${GOOGLE_LIBRARIES} pthread tc_cuda_version)
 
 ################################################################################
 # Core GPU library tests
diff --git a/test/cuda/test_autotuner.cc b/test/cuda/test_autotuner.cc
index 71e36a30a..e03f994a2 100644
--- a/test/cuda/test_autotuner.cc
+++ b/test/cuda/test_autotuner.cc
@@ -178,6 +178,68 @@ def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) {
   benchmarkKernelOptions(TC, name, inputs, bestOptions);
 }
 
+TEST_F(ATenCompilationUnitTest, BatchNorm) {
+  static constexpr uint32_t N = 32, C = 4, H = 256, W = 256;
+  at::Tensor I = at::CUDA(at::kFloat).rand({N, C, H, W});
+  at::Tensor rMeanIn = at::CUDA(at::kFloat).rand({C});
+  at::Tensor rVarIn = at::CUDA(at::kFloat).rand({C});
+  at::Tensor eps = at::CUDA(at::kFloat).rand({1});
+  eps[0] = 1.0f;
+  at::Tensor momentum = at::CUDA(at::kFloat).rand({1});
+  momentum[0] = 1.0;
+
+  std::vector<at::Tensor> inputs = {momentum, eps, I, rMeanIn, rVarIn};
+
+  static constexpr auto TC = R"TC(
+def batchnorm(
+    float(1) momentum, float(1) eps,
+    float(N,C,H,W) I, float(C) rMeanIn, float(C) rVarIn)
+-> (normalizedOut, rMeanOut, rVarOut, mean, centered, expectedVariance)
+{
+  mean(c) +=! I(r_n, c, r_h, r_w)
+  mean(c) = mean(c) / (N * H * W)
+  rMeanOut(c) = (1 - momentum(0)) * rMeanIn(c) + momentum(0) * mean(c)
+  centered(n, c, h, w) = I(n, c, h, w) - rMeanOut(c)
+  expectedVariance(c) +=! centered(n, c, h, w) * centered(n, c, h, w)
+  expectedVariance(c) = expectedVariance(c) / (N * H * W) + eps(0)
+  rVarOut(c) = rsqrt(
+      (1 - momentum(0)) * rVarIn(c) +
+      momentum(0) * expectedVariance(c))
+  normalizedOut(n, c, h, w) = centered(n, c, h, w) * rVarOut(c)
+})TC";
+
+  auto name = "batchnorm";
+  auto options = tc::CudaMappingOptions::makeNaiveMappingOptions();
+
+  std::string cacheFilename = "";
+  auto bestOptions = autotune(TC, name, inputs, cacheFilename, options);
+}
+
+TEST_F(ATenCompilationUnitTest, GroupConvolution) {
+  uint32_t N = 32, G = 32, C = 16, F = 16;
+  uint32_t W = 14, H = 14, KW = 3, KH = 3;
+  at::Tensor I = at::CUDA(at::kFloat).rand({N, G, C, H, W});
+  at::Tensor W1 = at::CUDA(at::kFloat).rand({G, F, C, KH, KW});
+  at::Tensor B = at::CUDA(at::kFloat).rand({G, F});
+  std::vector<at::Tensor> inputs = {I, W1, B};
+
+  static constexpr auto TC = R"(
+def group_convolution(float(N,G,C,H,W) I, float(G,F,C,KH,KW) W1, float(G,F) B)
+-> (O)
+{
+  O(n, g, f, h, w) +=!
+      I(n, g, r_c, h + r_kh, w + r_kw) * W1(g, f, r_c, r_kh, r_kw)
+  O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f)
+}
+)";
+
+  auto name = "group_convolution";
+  auto options = tc::CudaMappingOptions::makeNaiveMappingOptions();
+
+  std::string cacheFilename = "";
+  auto bestOptions = autotune(TC, name, inputs, cacheFilename, options);
+}
+
 int main(int argc, char** argv) {
   ::testing::InitGoogleTest(&argc, argv);
   ::gflags::ParseCommandLineFlags(&argc, &argv, true);
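The group_convolution TC above is a plain "valid" grouped convolution followed by a bias add, so the output spatial size is (H - KH + 1) x (W - KW + 1). A naive host-side reference with illustrative names and layouts (not part of this patch) for anyone who wants to check the autotuned kernel:

#include <vector>

// Naive grouped convolution matching the TC def above:
// O(n,g,f,h,w) = sum_{c,kh,kw} I(n,g,c,h+kh,w+kw) * W1(g,f,c,kh,kw) + B(g,f).
// Illustrative only; flat row-major layouts are an assumption.
std::vector<float> referenceGroupConv(
    const std::vector<float>& I,  // N*G*C*H*W
    const std::vector<float>& W1, // G*F*C*KH*KW
    const std::vector<float>& B,  // G*F
    int N, int G, int C, int F, int H, int W, int KH, int KW) {
  int OH = H - KH + 1, OW = W - KW + 1;
  std::vector<float> O(size_t(N) * G * F * OH * OW, 0.f);
  for (int n = 0; n < N; ++n)
    for (int g = 0; g < G; ++g)
      for (int f = 0; f < F; ++f)
        for (int h = 0; h < OH; ++h)
          for (int w = 0; w < OW; ++w) {
            float acc = 0.f;
            for (int c = 0; c < C; ++c)
              for (int kh = 0; kh < KH; ++kh)
                for (int kw = 0; kw < KW; ++kw)
                  acc += I[((size_t(n) * G + g) * C + c) * H * W +
                           size_t(h + kh) * W + (w + kw)] *
                         W1[((size_t(g) * F + f) * C + c) * KH * KW +
                            size_t(kh) * KW + kw];
            O[((size_t(n) * G + g) * F + f) * OH * OW + size_t(h) * OW + w] =
                acc + B[size_t(g) * F + f];
          }
  return O;
}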
diff --git a/test/cuda/test_basic_gpu.cc b/test/cuda/test_basic_gpu.cc
index 4b856820e..e6815b29e 100644
--- a/test/cuda/test_basic_gpu.cc
+++ b/test/cuda/test_basic_gpu.cc
@@ -27,10 +27,12 @@
 #include
 
 #include "tc/core/cuda/cuda.h"
+#include "tc/version/cuda_version.h"
 
 std::vector<char> jitCompile(
     std::string cuda,
-    std::vector<const char*> extraCompileOptions = std::vector<const char*>{}) {
+    std::vector<const char*> extraCompileOptions = std::vector<const char*>{},
+    bool useGridSync = false) {
   // Actually do the compiling.
   nvrtcProgram prog;
   TC_NVRTC_CHECK(
@@ -58,6 +60,9 @@ std::vector<char> jitCompile(
       "-DNVRTC_CUB=1",
       "-lineinfo",
       cudaHome.c_str()};
+  if (useGridSync) {
+    nvrtcts.push_back("--relocatable-device-code=true");
+  }
   for (auto o : extraCompileOptions) {
     nvrtcts.push_back(o);
   }
@@ -83,7 +88,7 @@ std::vector<char> jitCompile(
   return PTX;
 }
 
-void loadUnload(const std::string& ptx) {
+void loadUnload(const std::string& ptx, bool useGridSync = false) {
   CUdevice cuDevice;
   CUcontext context;
   TC_CUDA_DRIVERAPI_ENFORCE(cuInit(0));
@@ -92,8 +97,28 @@
   CUmodule m;
   CUfunction k;
-  TC_CUDA_DRIVERAPI_ENFORCE(cuModuleLoadDataEx(&m, ptx.c_str(), 0, 0, 0));
-  TC_CUDA_DRIVERAPI_ENFORCE(cuModuleGetFunction(&k, m, "foo"));
+  if (useGridSync) {
+    CUlinkState linkState;
+    TC_CUDA_DRIVERAPI_ENFORCE(cuLinkCreate(0, 0, 0, &linkState));
+    TC_CUDA_DRIVERAPI_ENFORCE(cuLinkAddFile(
+        linkState, CU_JIT_INPUT_LIBRARY, tc::cuda_libdevrt_path, 0, 0, 0));
+    TC_CUDA_DRIVERAPI_ENFORCE(cuLinkAddData(
+        linkState,
+        CU_JIT_INPUT_PTX,
+        (void*)ptx.data(),
+        ptx.size(),
+        "device_code.ptx",
+        0,
+        0,
+        0));
+    size_t cubinSize;
+    void* cubin;
+    TC_CUDA_DRIVERAPI_ENFORCE(cuLinkComplete(linkState, &cubin, &cubinSize));
+    TC_CUDA_DRIVERAPI_ENFORCE(cuModuleLoadData(&m, cubin));
+  } else {
+    TC_CUDA_DRIVERAPI_ENFORCE(cuModuleLoadDataEx(&m, ptx.c_str(), 0, 0, 0));
+    TC_CUDA_DRIVERAPI_ENFORCE(cuModuleGetFunction(&k, m, "foo"));
+  }
   TC_CUDA_DRIVERAPI_ENFORCE(cuModuleUnload(m));
 }
 
@@ -113,6 +138,21 @@ __global__ void foo(int N)
   EXPECT_NE(std::string::npos, ptx.find(expected));
 }
 
+TEST(BasicGpuTest, NvrtcGridSync) {
+  auto PTX = jitCompile(R"CUDA(
+extern "C" {
+__global__ void foo(int N)
+{
+  assert(N == 1);
+}
+})CUDA", {"-G"}, true);
+
+  std::string ptx(PTX.data());
+  loadUnload(ptx, true);
+  auto expected = R"E(.visible .entry foo()E";
+  EXPECT_NE(std::string::npos, ptx.find(expected));
+}
+
 TEST(BasicGpuTest, CubReduce) {
   std::string path(CUB_HOME);
   std::string include = std::string("-I ") + path;
@@ -149,6 +189,26 @@ __global__ void bar(float* o, const float* i) {
   EXPECT_NE(std::string::npos, ptx.find(expected));
 }
 
+TEST(BasicGpuTest, GridSync) {
+  auto PTX = jitCompile(
+      R"CUDA(
+__device__ void __syncgrid() {
+  cudaCGSynchronize(cudaCGGetIntrinsicHandle(cudaCGScopeGrid), 0);
+}
+
+extern "C" {
+__global__ void foo() {
+  __syncgrid();
+}
+})CUDA",
+      {"-G"},
+      true);
+  std::string ptx(PTX.data());
+  loadUnload(ptx, true);
+  auto expected = R"E(.visible .entry foo()E";
+  EXPECT_NE(std::string::npos, ptx.find(expected));
+}
+
 namespace {
 // Mark the function argument as __restrict__ depending on the flag.
 std::string makeFuncWithOptionalRestrict(bool useRestrict) {
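Note that the two grid-sync tests above only compile the PTX with relocatable device code, link it against tc::cuda_libdevrt_path, and load/unload the resulting module; the kernel is never launched. A kernel that performs grid-wide synchronization additionally has to go through the cooperative-launch path, roughly as in the following sketch (a hypothetical helper, not part of this patch, assuming the CUmodule built in loadUnload is kept alive and the device supports cooperative launch):

#include <cuda.h>

#include "tc/core/cuda/cuda.h"

// Hypothetical helper: launch a kernel from an already-linked module
// cooperatively so that grid-wide synchronization is legal at runtime.
void launchCooperatively(CUmodule m, const char* kernelName) {
  CUfunction k;
  TC_CUDA_DRIVERAPI_ENFORCE(cuModuleGetFunction(&k, m, kernelName));
  // The whole grid must be able to be co-resident on the device; the sizes
  // below are placeholders, not values taken from this patch.
  TC_CUDA_DRIVERAPI_ENFORCE(cuLaunchCooperativeKernel(
      k,
      /*gridDimX=*/16, /*gridDimY=*/1, /*gridDimZ=*/1,
      /*blockDimX=*/256, /*blockDimY=*/1, /*blockDimZ=*/1,
      /*sharedMemBytes=*/0,
      /*hStream=*/nullptr,
      /*kernelParams=*/nullptr)); // foo() in the GridSync test takes no args
  TC_CUDA_DRIVERAPI_ENFORCE(cuCtxSynchronize());
}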
diff --git a/test/test_cuda_mapper.cc b/test/test_cuda_mapper.cc
index 3f43178ad..a632d8f9b 100644
--- a/test/test_cuda_mapper.cc
+++ b/test/test_cuda_mapper.cc
@@ -90,6 +90,7 @@ struct PolyhedralMapperTest : public ::testing::Test {
         std::move(scop),
         Grid{1},
         Block{blockSizes[0], blockSizes[1]},
+        false,
         0,
         false);
     auto band = mscop->mapBlocksForward(root->child({0}), 1);
@@ -113,6 +114,7 @@ struct PolyhedralMapperTest : public ::testing::Test {
         std::move(scop),
         Grid{gridSizes[0], gridSizes[1]},
         Block{blockSizes[0], blockSizes[1]},
+        false,
         0,
         false);
@@ -666,11 +668,7 @@ def fun() -> (O) {
 /*
  * Check that a schedule tree without a single outer band gets mapped
- * properly, i.e., that a band is inserted above the branching.
- * Use the minimal fusion strategy to ensure the scheduler produces
- * an outer sequence.
- * Check that no synchronizations are inserted, since there is no
- * dependences between threads.
+ * on blocks, and has a grid synchronization.
  */
 TEST_F(PolyhedralMapperTest, Copy2) {
   auto tc = R"TC(
@@ -682,14 +680,9 @@ def fun(float(N) I) -> (O1, O2) {
   auto mappingOptions = DefaultOptions();
   mappingOptions.scheduleFusionStrategy(FusionStrategy::Min);
   auto code = codegenMapped(tc, mappingOptions);
-  auto loop = "for (int c0 = t0; c0 < N; c0 += 32)";
-  auto blockSync = "__syncthreads();";
-  auto pos1 = code.find(loop);
-  auto pos2 = code.find(loop, pos1 + 1);
-  auto pos3 = code.find(blockSync);
-  EXPECT_TRUE(pos1 != std::string::npos);
-  EXPECT_TRUE(pos2 != std::string::npos);
-  EXPECT_TRUE(pos3 == std::string::npos);
+  auto gridSync = "__syncgrid()";
+  auto pos = code.find(gridSync);
+  EXPECT_TRUE(pos != std::string::npos);
 }
 
 /*
diff --git a/test/test_cuda_mapper_memory_promotion.cc b/test/test_cuda_mapper_memory_promotion.cc
index ee55a10c5..1691d96ab 100644
--- a/test/test_cuda_mapper_memory_promotion.cc
+++ b/test/test_cuda_mapper_memory_promotion.cc
@@ -394,7 +394,12 @@ def fun(float(N, M) A) -> (B, C) {
       size_t maxSharedMemory) {
     auto mscop = prepareScop(
         tc, {{"N", problemSize1}, {"M", problemSize2}}, {tileSize1, tileSize2});
-    promoteGreedilyAtDepth(*mscop, depth, maxSharedMemory, false);
+    promoteGreedilyAtDepth(
+        *mscop,
+        {mscop->scop().scheduleRoot()},
+        {depth},
+        maxSharedMemory,
+        false);
     return mscop;
   }
 };
diff --git a/third-party/islpp b/third-party/islpp
index 12a5b72fb..0cf707e50 160000
--- a/third-party/islpp
+++ b/third-party/islpp
@@ -1 +1 @@
-Subproject commit 12a5b72fb7c5fe4a841b691287ce4233a30f4b0f
+Subproject commit 0cf707e5003570c9df45a0b5b776429b5f0ce298
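One caveat for the new Copy2 expectation and the grid-sync tests: grid-wide synchronization relies on cooperative launch, which only some devices report support for. A small runtime-API guard such tests could check first (illustrative sketch, not part of this patch):

#include <cuda_runtime.h>

// Illustrative guard: true if the current device can run cooperatively
// launched kernels, which grid-wide synchronization requires.
bool deviceSupportsGridSync() {
  int dev = 0;
  int supported = 0;
  if (cudaGetDevice(&dev) != cudaSuccess) {
    return false;
  }
  if (cudaDeviceGetAttribute(&supported, cudaDevAttrCooperativeLaunch, dev) !=
      cudaSuccess) {
    return false;
  }
  return supported != 0;
}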