From 4da18a98033705383c838af3e5d1568864343008 Mon Sep 17 00:00:00 2001 From: math-fehr Date: Fri, 1 Jun 2018 18:19:07 +0200 Subject: [PATCH 1/2] Move relativeRange from unroll.cc to external/detail/islpp-inl.h --- tc/core/polyhedral/unroll.cc | 41 ---------------------------------- tc/external/detail/islpp-inl.h | 24 ++++++++++++++++++++ tc/external/detail/islpp.h | 23 +++++++++++++++++++ 3 files changed, 47 insertions(+), 41 deletions(-) diff --git a/tc/core/polyhedral/unroll.cc b/tc/core/polyhedral/unroll.cc index 7047bd0d1..9c483d43c 100644 --- a/tc/core/polyhedral/unroll.cc +++ b/tc/core/polyhedral/unroll.cc @@ -23,47 +23,6 @@ namespace tc { namespace polyhedral { namespace { -/* - * Return a bound on the range of values attained by "f" for fixed - * values of "fixed", taking into account basic strides - * in the range of values attained by "f". - * - * First construct a map from values of "fixed" to corresponding - * values of "f". If this map is empty, then "f" cannot attain - * any values and the bound is zero. - * Otherwise, consider pairs of "f" values for the same value - * of "fixed" and take their difference over all possible values - * of the parameters and of the "fixed" values. - * Take a simple overapproximation as a convex set and - * determine the stride is the value differences. - * The possibly quasi-affine set is then overapproximated by an affine set. - * At this point, the set is a possibly infinite, symmetrical interval. - * Take the maximal value of the difference divided by the stride plus one as - * a bound on the number of possible values of "f". - * That is, take M/s + 1. Note that 0 is always an element of - * the difference set, so no offset needs to be taken into account - * during the stride computation and M is an integer multiple of s. 
- */
-isl::val relativeRange(isl::union_map fixed, isl::union_pw_aff f) {
-  auto ctx = f.get_ctx();
-  auto umap = isl::union_map::from(isl::multi_union_pw_aff(f));
-  umap = umap.apply_domain(fixed);
-  if (umap.is_empty()) {
-    return isl::val::zero(ctx);
-  }
-
-  umap = umap.range_product(umap);
-  umap = umap.range().unwrap();
-  umap = umap.project_out_all_params();
-  auto delta = isl::map::from_union_map(umap).deltas();
-  auto hull = delta.simple_hull();
-  auto stride = isl::set(hull).get_stride(0);
-  hull = isl::set(hull).polyhedral_hull();
-  auto bound = hull.dim_max_val(0);
-  bound = bound.div(stride);
-  bound = bound.add(isl::val::one(ctx));
-  return bound;
-}
 
 /*
  * Compute a bound on the number of instances executed by "band" and
diff --git a/tc/external/detail/islpp-inl.h b/tc/external/detail/islpp-inl.h
index 38edad204..dcbd1c73d 100644
--- a/tc/external/detail/islpp-inl.h
+++ b/tc/external/detail/islpp-inl.h
@@ -215,4 +215,28 @@ inline isl::set operator&(isl::point P1, isl::set S2) {
   return S2 & P1;
 }
 
+///////////////////////////////////////////////////////////////////////////////
+// Helper functions
+///////////////////////////////////////////////////////////////////////////////
+
+inline isl::val relativeRange(isl::union_map fixed, isl::union_pw_aff f) {
+  auto ctx = f.get_ctx();
+  auto umap = isl::union_map::from(isl::multi_union_pw_aff(f));
+  umap = umap.apply_domain(fixed);
+  if (umap.is_empty()) {
+    return isl::val::zero(ctx);
+  }
+
+  umap = umap.range_product(umap);
+  umap = umap.range().unwrap();
+  umap = umap.project_out_all_params();
+  auto delta = isl::map::from_union_map(umap).deltas();
+  auto hull = delta.simple_hull();
+  auto stride = isl::set(hull).get_stride(0);
+  hull = isl::set(hull).polyhedral_hull();
+  auto bound = hull.dim_max_val(0);
+  bound = bound.div(stride);
+  bound = bound.add(isl::val::one(ctx));
+  return bound;
+}
 } // namespace isl
diff --git a/tc/external/detail/islpp.h b/tc/external/detail/islpp.h
index affbd51ae..5546d8acc 100644
--- a/tc/external/detail/islpp.h
+++ b/tc/external/detail/islpp.h
@@ -347,6 +347,29 @@ inline isl::set makeSpecializationSet(
   return makeSpecializationSet(space, map);
 }
 
+/*
+ * Return a bound on the range of values attained by "f" for fixed
+ * values of "fixed", taking into account basic strides
+ * in the range of values attained by "f".
+ *
+ * First construct a map from values of "fixed" to corresponding
+ * values of "f". If this map is empty, then "f" cannot attain
+ * any values and the bound is zero.
+ * Otherwise, consider pairs of "f" values for the same value
+ * of "fixed" and take their difference over all possible values
+ * of the parameters and of the "fixed" values.
+ * Take a simple overapproximation as a convex set and
+ * determine the stride in the value differences.
+ * The possibly quasi-affine set is then overapproximated by an affine set.
+ * At this point, the set is a possibly infinite, symmetrical interval.
+ * Take the maximal value of the difference divided by the stride plus one as
+ * a bound on the number of possible values of "f".
+ * That is, take M/s + 1. Note that 0 is always an element of
+ * the difference set, so no offset needs to be taken into account
+ * during the stride computation and M is an integer multiple of s.
+ */
+inline isl::val relativeRange(isl::union_map fixed, isl::union_pw_aff f);
+
 namespace detail {
 
 // Helper class used to support range-based for loops on isl::*_list types.
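[Note: a usage sketch for relativeRange, illustrative only and not part of the
patch. It assumes the isl C++ string constructors exposed through TC's islpp
wrappers; relativeRangeExample is a hypothetical name.]

    #include "tc/external/detail/islpp.h"

    // f(i) = 2*i on 0 <= i < 11 attains the values {0, 2, ..., 20}. With
    // nothing fixed, the pairwise value differences are {-20, ..., 20} with
    // stride s = 2 and maximum M = 20, so relativeRange returns
    // M/s + 1 = 11, exactly the number of distinct values f can attain.
    isl::val relativeRangeExample(isl::ctx ctx) {
      auto f = isl::union_pw_aff(ctx, "{ S[i] -> [(2 * i)] : 0 <= i < 11 }");
      // Send every instance to one common tuple: no values are held fixed.
      auto fixed = isl::union_map(ctx, "{ S[i] -> F[] : 0 <= i < 11 }");
      return isl::relativeRange(fixed, f); // == 11
    }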
From 5a9f4fa8ebd9ab835a6b318402df71d447c44d3e Mon Sep 17 00:00:00 2001
From: math-fehr
Date: Wed, 30 May 2018 18:20:21 +0200
Subject: [PATCH 2/2] Implement timeout for CUDA using a mapping option

Also add a timeout flag that sets the default timeout in the options.
However, the flag has no effect when the options are initialized before
gflags parses the command line. The flag sets the timeout in ms.

The __timestamp() function is used to get the current timestamp.
According to NVIDIA this function should not be relied upon, so it
could behave differently on some devices.

To implement a timeout in CUDA, the blocks first need to know the
timestamp of the kernel launch. To that end, the first instruction of
every block retrieves the timestamp stored in global memory. If the
value is 0 (as it is at the start of the kernel), the block sets it to
the current timestamp. All of that is done atomically. It might happen
that the stored timestamp is not the lowest timestamp computed by the
blocks, but it is close.

After that, timeout checks are inserted in the kernel code; they check
whether the kernel has run for more than n ns. The checks are inserted
after some for loops and some sequences, and they are placed such that
at least timeout_check_frequency loop iterations separate two
consecutive checks, to ensure that the checks do not influence the
results much. timeout_check_frequency can be modified by a flag.
---
 tc/core/compiler-inl.h                        |   6 +-
 tc/core/constants.h                           |   2 +
 tc/core/cpu/cpu_tc_executor.cc                |   3 +-
 tc/core/cpu/cpu_tc_executor.h                 |   3 +-
 tc/core/cuda/cuda_mapping_options.cc          |   8 +-
 tc/core/cuda/cuda_mapping_options.h           |   5 +
 .../cuda/cuda_mapping_options_cpp_printer.cc  |   3 +
 tc/core/cuda/cuda_rtc.cc                      |  27 ++-
 tc/core/cuda/cuda_rtc.h                       |   2 +
 tc/core/cuda/cuda_tc_executor.cc              |  10 +-
 tc/core/cuda/cuda_tc_executor.h               |   5 +-
 tc/core/flags.cc                              |   6 +
 tc/core/flags.h                               |   4 +
 tc/core/libraries.h                           |   7 +
 tc/core/polyhedral/cuda/codegen.cc            |  45 ++++-
 tc/core/polyhedral/cuda/mapped_scop.cc        | 160 +++++++++++++++++-
 tc/core/polyhedral/cuda/mapped_scop.h         |  28 ++-
 tc/core/polyhedral/scop.h                     |  23 +++
 tc/proto/mapping_options.proto                |   2 +
 test/test_cuda_mapper.cc                      | 113 ++++++++-----
 20 files changed, 393 insertions(+), 69 deletions(-)

diff --git a/tc/core/compiler-inl.h b/tc/core/compiler-inl.h
index e72ddfadb..616b6f993 100644
--- a/tc/core/compiler-inl.h
+++ b/tc/core/compiler-inl.h
@@ -66,6 +66,10 @@ std::unique_ptr<typename Backend::ExecutorType> compile(
       options);
   return std::unique_ptr<typename Backend::ExecutorType>(
       new typename Backend::ExecutorType(
-          inputsInfo, outputsInfo, halideComponents, compilationResult));
+          inputsInfo,
+          outputsInfo,
+          halideComponents,
+          compilationResult,
+          options));
 }
 } // namespace tc
diff --git a/tc/core/constants.h b/tc/core/constants.h
index 8ec277583..503878a4d 100644
--- a/tc/core/constants.h
+++ b/tc/core/constants.h
@@ -30,5 +30,7 @@ constexpr auto kWriteIdName = "write";
 constexpr auto kSyncIdPrefix = "_sync_";
 constexpr auto kWarpSyncIdPrefix = "_warpSync_";
 
+constexpr auto kTimeoutCheckPrefix = "_timeoutCheck_";
+
 } // namespace polyhedral
 } // namespace tc
diff --git a/tc/core/cpu/cpu_tc_executor.cc b/tc/core/cpu/cpu_tc_executor.cc
index a76efc714..924ecca0b 100644
--- a/tc/core/cpu/cpu_tc_executor.cc
+++ b/tc/core/cpu/cpu_tc_executor.cc
@@ -29,7 +29,8 @@ CpuTcExecutor::CpuTcExecutor(
     const std::vector<TensorInfo>& inputsInfo,
     const std::vector<TensorInfo>& outputsInfo,
     const tc2halide::HalideComponents& halideComponents,
-    const typename CpuBackend::CompilationResultType& compilationResult)
+    const typename CpuBackend::CompilationResultType& compilationResult,
+    const typename CpuBackend::MappingOptionsType& options)
     : TcExecutor<CpuBackend>(
           inputsInfo,
           outputsInfo,
diff --git a/tc/core/cpu/cpu_tc_executor.h b/tc/core/cpu/cpu_tc_executor.h
index 079ae1589..e335caaf6 100644
--- a/tc/core/cpu/cpu_tc_executor.h
+++ b/tc/core/cpu/cpu_tc_executor.h
@@ -30,7 +30,8 @@ class CpuTcExecutor : public TcExecutor<CpuBackend> {
       const std::vector<TensorInfo>& inputsInfo,
      const std::vector<TensorInfo>& outputsInfo,
       const tc2halide::HalideComponents& halideComponents,
-      const typename CpuBackend::CompilationResultType& compilationResult);
+      const typename CpuBackend::CompilationResultType& compilationResult,
+      const typename CpuBackend::MappingOptionsType& options);
 
   /// This is the "low-latency" mode in which we just propagate raw pointers to
   /// data in the address space where kernel is executed.
diff --git a/tc/core/cuda/cuda_mapping_options.cc b/tc/core/cuda/cuda_mapping_options.cc
index 826939488..a1b190b45 100644
--- a/tc/core/cuda/cuda_mapping_options.cc
+++ b/tc/core/cuda/cuda_mapping_options.cc
@@ -287,6 +287,11 @@ CudaMappingOptions& CudaMappingOptions::useReadOnlyCache(bool b) {
   return *this;
 }
 
+CudaMappingOptions& CudaMappingOptions::timeout(uint32_t ms) {
+  ownedProto_.set_timeout(ms);
+  return *this;
+}
+
 CudaMappingOptions& CudaMappingOptions::mapToThreads(
     const std::string& commaSeparatedSizes) {
   auto sizes = parseCommaSeparatedIntegers(commaSeparatedSizes);
@@ -318,7 +323,8 @@ CudaMappingOptions CudaMappingOptions::makeUnmappedMappingOptions() {
       .useSharedMemory(false)
       .usePrivateMemory(false)
       .unrollCopyShared(false)
-      .useReadOnlyCache(false);
+      .useReadOnlyCache(false)
+      .timeout(FLAGS_timeout);
   return mo;
 }
diff --git a/tc/core/cuda/cuda_mapping_options.h b/tc/core/cuda/cuda_mapping_options.h
index 8afb6e2b0..15cb0b459 100644
--- a/tc/core/cuda/cuda_mapping_options.h
+++ b/tc/core/cuda/cuda_mapping_options.h
@@ -197,6 +197,11 @@ class CudaMappingOptions {
   CudaMappingOptions& useReadOnlyCache(bool b);
   ///@}
 
+  /// Set the kernel timeout in ms; 0 disables the timeout checks
+  ///@{
+  CudaMappingOptions& timeout(uint32_t ms);
+  ///@}
+
   /// Static constructors for predefined strategies.
  ///@{
  static CudaMappingOptions makeNaiveMappingOptions();
diff --git a/tc/core/cuda/cuda_mapping_options_cpp_printer.cc b/tc/core/cuda/cuda_mapping_options_cpp_printer.cc
index 9e46367f6..a2ff884c5 100644
--- a/tc/core/cuda/cuda_mapping_options_cpp_printer.cc
+++ b/tc/core/cuda/cuda_mapping_options_cpp_printer.cc
@@ -38,6 +38,9 @@ CudaMappingOptionsCppPrinter& operator<<(
     prn.printValueOption(
         "maxSharedMemory", cudaOptions.proto().max_shared_memory());
   }
+  if (cudaOptions.proto().has_timeout()) {
+    prn.printValueOption("timeout", cudaOptions.proto().timeout());
+  }
   prn.endStmt();
   return prn;
 }
diff --git a/tc/core/cuda/cuda_rtc.cc b/tc/core/cuda/cuda_rtc.cc
index b25c968a9..abae4d716 100644
--- a/tc/core/cuda/cuda_rtc.cc
+++ b/tc/core/cuda/cuda_rtc.cc
@@ -28,7 +28,10 @@ namespace tc {
 
 std::mutex nvrtc_mutex;
 
-CudaRTCFunction::CudaRTCFunction() {}
+CudaRTCFunction::CudaRTCFunction() {
+  TC_CUDA_RUNTIMEAPI_ENFORCE(
+      cudaMalloc((void**)&startTimeDev, sizeof(unsigned long long)));
+}
 
 CudaRTCFunction::~CudaRTCFunction() {
   if (!cleared_) {
@@ -43,6 +46,7 @@ void CudaRTCFunction::clear() {
       WithCudaDevice(kvp.first);
       TC_CUDA_DRIVERAPI_ENFORCE(cuModuleUnload(kvp.second));
     }
+    TC_CUDA_RUNTIMEAPI_ENFORCE(cudaFree((void*)startTimeDev));
     cleared_ = true;
   }
 }
@@ -136,7 +140,9 @@ Duration CudaRTCFunction::Launch(
     std::vector<long> params,
     std::vector<void*> outputs,
     std::vector<const void*> inputs,
+    uint32_t timeout,
     bool profile) const {
+  uint64_t timeoutInNs = static_cast<uint64_t>(timeout) * 1000 * 1000;
   int dev;
   TC_CUDA_RUNTIMEAPI_ENFORCE(cudaGetDevice(&dev));
   if (perGpuModule_.count(dev) == 0) {
@@ -152,11 +158,19 @@
   constexpr size_t kNumMaxParameters = 100;
   std::array<void*, kNumMaxParameters> args_voidp{0};
-  CHECK_GE(kNumMaxParameters, params.size() + outputs.size() + inputs.size());
+  CHECK_GE(
+      kNumMaxParameters,
+      params.size() + outputs.size() + inputs.size() + 2 * (timeout != 0));
   int ind = 0;
   for (auto& p : params) {
     args_voidp[ind++] = &p;
   }
+  if (timeout != 0) {
+    args_voidp[ind++] =
+        const_cast<void*>(static_cast<const void*>(&startTimeDev));
+    args_voidp[ind++] =
+        const_cast<void*>(static_cast<const void*>(&timeoutInNs));
+  }
   for (auto& o : outputs) {
     args_voidp[ind++] = &o;
   }
@@ -171,6 +185,15 @@
   unsigned int bx = block[0];
   unsigned int by = block[1];
   unsigned int bz = block[2];
+  if (timeout != 0) {
+    unsigned long long startTime = 0;
+    TC_CUDA_RUNTIMEAPI_ENFORCE(cudaMemcpy(
+        (void*)startTimeDev,
+        (void*)&startTime,
+        sizeof(unsigned long long),
+        cudaMemcpyHostToDevice));
+  }
+
   auto launch = [&]() {
     TC_CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel(
         perGpuKernel_.at(dev),
diff --git a/tc/core/cuda/cuda_rtc.h b/tc/core/cuda/cuda_rtc.h
index 4aa5e831f..cfd7ba196 100644
--- a/tc/core/cuda/cuda_rtc.h
+++ b/tc/core/cuda/cuda_rtc.h
@@ -57,6 +57,7 @@ class CudaRTCFunction {
       std::vector<long> params,
       std::vector<void*> outputs,
       std::vector<const void*> inputs,
+      uint32_t timeout,
       bool profile = false) const;
 
   void clear();
@@ -64,6 +65,7 @@ class CudaRTCFunction {
  private:
   mutable std::unordered_map<size_t, CUmodule> perGpuModule_;
   mutable std::unordered_map<size_t, CUfunction> perGpuKernel_;
+  unsigned long long* startTimeDev;
   std::string specializedName;
   std::vector<char> nvrtc_ptx;
   bool cleared_;
diff --git a/tc/core/cuda/cuda_tc_executor.cc b/tc/core/cuda/cuda_tc_executor.cc
index 72a1350ad..819b0c5aa 100644
--- a/tc/core/cuda/cuda_tc_executor.cc
+++ b/tc/core/cuda/cuda_tc_executor.cc
@@ -46,12 +46,14 @@ CudaTcExecutor::CudaTcExecutor(
     const std::vector<TensorInfo>& inputsInfo,
     const std::vector<TensorInfo>& outputsInfo,
     const tc2halide::HalideComponents& halideComponents,
-    const typename CudaBackend::CompilationResultType& compilationResult)
+    const typename CudaBackend::CompilationResultType& compilationResult,
+    const typename CudaBackend::MappingOptionsType& options)
     : TcExecutor<CudaBackend>(
           inputsInfo,
           outputsInfo,
           halideComponents,
-          compilationResult) {
+          compilationResult),
+      timeout_(options.proto().timeout()) {
   auto t0 = std::chrono::high_resolution_clock::now();
   // force unloading in case we JIT with the same name/input/outputs with
   // different options.
@@ -121,7 +123,8 @@ void CudaTcExecutor::uncheckedRun(
       info.stream,
       parameters_,
       outputs,
-      inputs);
+      inputs,
+      timeout_);
 }
 
 ProfilingInfo CudaTcExecutor::profileUnchecked(
@@ -140,6 +143,7 @@
       parameters_,
       outputs,
       inputs,
+      timeout_,
       true));
   // The CPU overhead is the total time minus the (synchronized) kernel runtime
   Duration cpuOverhead(Duration::since(start));
diff --git a/tc/core/cuda/cuda_tc_executor.h b/tc/core/cuda/cuda_tc_executor.h
index a497c940d..596ec8e70 100644
--- a/tc/core/cuda/cuda_tc_executor.h
+++ b/tc/core/cuda/cuda_tc_executor.h
@@ -30,7 +30,8 @@ class CudaTcExecutor : public TcExecutor<CudaBackend> {
       const std::vector<TensorInfo>& inputsInfo,
       const std::vector<TensorInfo>& outputsInfo,
       const tc2halide::HalideComponents& halideComponents,
-      const typename CudaBackend::CompilationResultType& compilationResult);
+      const typename CudaBackend::CompilationResultType& compilationResult,
+      const typename CudaBackend::MappingOptionsType& options);
 
   /// This is the "low-latency" mode in which we just propagate raw pointers to
   /// data in the address space where kernel is executed.
@@ -63,5 +64,7 @@ class CudaTcExecutor : public TcExecutor<CudaBackend> {
   // GPU-specific results of compilation
   Grid grid_;
   Block block_;
+
+  uint32_t timeout_;
 };
 } // namespace tc
diff --git a/tc/core/flags.cc b/tc/core/flags.cc
index 80c9e5ec5..c79b99d6c 100644
--- a/tc/core/flags.cc
+++ b/tc/core/flags.cc
@@ -38,6 +38,12 @@ DEFINE_bool(
 DEFINE_bool(dump_cuda, false, "Print the generated source");
 DEFINE_bool(dump_ptx, false, "Dump the generated PTX");
 
+DEFINE_uint32(
+    timeout_check_frequency,
+    100,
+    "The minimum number of loop iterations between two timeout checks");
+DEFINE_uint32(timeout, 0, "The CUDA kernel timeout in ms (0 means no timeout)");
+
 // CPU codegen options
 DEFINE_bool(llvm_dump_before_opt, false, "Print IR before optimization");
 DEFINE_bool(llvm_dump_after_opt, false, "Print IR after optimization");
diff --git a/tc/core/flags.h b/tc/core/flags.h
index c748759a4..d69dd1294 100644
--- a/tc/core/flags.h
+++ b/tc/core/flags.h
@@ -31,6 +31,10 @@ DECLARE_bool(debug_tuner);
 DECLARE_bool(dump_cuda);
 DECLARE_bool(dump_ptx);
 
+// CUDA timeout
+DECLARE_uint32(timeout_check_frequency);
+DECLARE_uint32(timeout);
+
 // llvm codegen
 DECLARE_bool(llvm_dump_before_opt);
 DECLARE_bool(llvm_dump_after_opt);
diff --git a/tc/core/libraries.h b/tc/core/libraries.h
index 7605cf7f7..68e523f43 100644
--- a/tc/core/libraries.h
+++ b/tc/core/libraries.h
@@ -43,6 +43,13 @@ constexpr auto defines = R"C(
 #define inf __longlong_as_double(0x7ff0000000000000LL)
 )C";
 
+constexpr auto timestampFunction = R"C(
+__device__ unsigned long long __timestamp() {
+  unsigned long long startns;
+  asm volatile("mov.u64 %0,%%globaltimer;" : "=l"(startns));
+  return startns;
+})C";
+
 constexpr auto warpSyncFunctions = R"C(
 // Before CUDA 9, syncwarp is a noop since warps are always synchronized.
#if __CUDACC_VER_MAJOR__ < 9
diff --git a/tc/core/polyhedral/cuda/codegen.cc b/tc/core/polyhedral/cuda/codegen.cc
index 9ceb740b5..34642b444 100644
--- a/tc/core/polyhedral/cuda/codegen.cc
+++ b/tc/core/polyhedral/cuda/codegen.cc
@@ -95,8 +95,9 @@ struct AstPrinter {
   bool inReduction_ = false;
 };
 
-vector<string> emitParams(const Scop& scop) {
+vector<string> emitParams(const MappedScop& mappedScop) {
   vector<string> res;
+  const auto& scop = mappedScop.scop();
   res.reserve(scop.halide.params.size());
   // Halide params. One of these two vectors will be empty.
   for (auto p : scop.halide.params) {
@@ -104,6 +105,14 @@
     ss << p.type() << " " << p.name();
     res.push_back(ss.str());
   }
+  if (mappedScop.useTimeout) {
+    stringstream ssStartTime;
+    ssStartTime << "unsigned long long* startTime";
+    res.push_back(ssStartTime.str());
+    stringstream ssTimeout;
+    ssTimeout << "unsigned long long timeout";
+    res.push_back(ssTimeout.str());
+  }
   return res;
 }
@@ -136,9 +145,10 @@ vector<string> emitTypedTensorNames(const vector& tensors) {
   return res;
 }
 
-void emitArgs(stringstream& ss, const Scop& scop) {
+void emitArgs(stringstream& ss, const MappedScop& mappedScop) {
   // Order is: params, outs, ins
-  auto sigVec = emitParams(scop);
+  const auto& scop = mappedScop.scop();
+  auto sigVec = emitParams(mappedScop);
   sigVec = sigVec + emitTypedTensorNames(scop.halide.outputs);
   sigVec = sigVec + emitTypedTensorNames(scop.halide.inputs);
   for (auto& s : sigVec) {
@@ -152,10 +162,10 @@
 void emitKernelSignature(
     stringstream& ss,
     const std::string& specializedName,
-    const Scop& scop) {
+    const MappedScop& mappedScop) {
   CHECK_NE(specializedName, "") << "name not provided";
   ss << "__global__ void " << specializedName << "(";
-  emitArgs(ss, scop);
+  emitArgs(ss, mappedScop);
   ss << ") {" << endl;
 }
@@ -452,6 +462,10 @@ void AstPrinter::emitStmt(isl::ast_node_user node) {
   } else if (
       stmtId.get_name() == kReadIdName || stmtId.get_name() == kWriteIdName) {
     emitCopyStmt(statementContext);
+  } else if (context_.scop().isTimeoutCheckId(stmtId)) {
+    context_.ss << "if (__timestamp() - startns > timeout) {\n";
+    context_.ss << ws.tab() << ws.tab() << "return;\n";
+    context_.ss << ws.tab() << "}" << std::endl;
   } else { // regular statement
     auto mappedStmtId = statementContext.statementId();
     CHECK_EQ(stmtId, mappedStmtId)
@@ -668,6 +682,22 @@ void emitThreadIdInit(stringstream& ss, const MappedScop& scop) {
   ss << "int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;\n";
 }
 
+void emitTimestampInit(stringstream& ss) {
+  WS ws;
+  ss << ws.tab();
+  ss << "unsigned long long startns = __timestamp();\n";
+  ss << ws.tab();
+  ss << "unsigned long long old_startns = startns;\n";
+  ss << ws.tab();
+  ss << "old_startns = atomicCAS(startTime, 0, startns);\n";
+  ss << ws.tab();
+  ss << "if (old_startns < startns && startns - old_startns > timeout && old_startns != 0) {\n";
+  ss << ws.tab() << ws.tab();
+  ss << "return;\n";
+  ss << ws.tab();
+  ss << "}\n";
+}
+
 void emitTmpDecl(stringstream& ss, const Scop& scop) {
   for (const auto& kvp : scop.treeSyncUpdateMap) {
     WS ws;
@@ -752,12 +782,15 @@ string emitCudaKernel(
   }
 
   stringstream ss;
-  emitKernelSignature(ss, specializedName, scop);
+  emitKernelSignature(ss, specializedName, mscop);
   emitThreadIdInit(ss, mscop);
   emitTensorViews(ss, scop.halide.outputs, paramValues);
   emitTensorViews(ss, scop.halide.inputs, paramValues);
   emitTmpDecl(ss, scop);
   emitPromotedArrayViewsHalide(ss, scop);
+  if (mscop.useTimeout) {
+    emitTimestampInit(ss);
+  }
   NodeInfoMapType nodeInfoMap;
   auto collect = [&nodeInfoMap](
                      isl::ast_node n, isl::ast_build b) -> isl::ast_node {
diff --git a/tc/core/polyhedral/cuda/mapped_scop.cc b/tc/core/polyhedral/cuda/mapped_scop.cc
index 0304d269a..4f34e9d0b 100644
--- a/tc/core/polyhedral/cuda/mapped_scop.cc
+++ b/tc/core/polyhedral/cuda/mapped_scop.cc
@@ -894,14 +894,14 @@ std::unique_ptr<MappedScop> makeSpecializedMappedScop(
       grid,
       block,
       mappedScop.unroll,
-      mappedScop.useReadOnlyCache);
+      mappedScop.useReadOnlyCache,
+      mappedScop.useTimeout);
   res->insertMappingContext();
   LOG_IF(INFO, FLAGS_debug_tc_mapper)
       << "Codegen with tightened bounds [blocks:" << grid
       << ", threads:" << block << "] for tree:\n"
       << *res->schedule();
-
   return res;
 }
 } // namespace
@@ -918,6 +918,9 @@ std::tuple<std::string, tc::Grid, tc::Block> MappedScop::codegen(
   std::stringstream code;
   code << code::cpp::boundsAsTemplate << code::c::types << code::c::defines;
   code << code::c::warpSyncFunctions;
+  if (useTimeout) {
+    code << code::c::timestampFunction;
+  }
   code << std::endl;
   if (mappedScopForCodegen->scop().treeSyncUpdateMap.size() != 0) {
     code << code::cuda::common;
@@ -961,6 +964,151 @@ detail::ScheduleTree* MappedScop::splitOutReductionTileAndInsertSyncs(
   return child;
 }
 
+namespace {
+
+// Computes a lower bound on the maximal number of loop iterations executed
+// in the given schedule tree. Returns infinity if there is a parametric band.
+isl::val boundInstances(detail::ScheduleTree* st, isl::union_map prefix) {
+  // If the tree is a leaf, there is only one iteration.
+  if (st->children().size() == 0) {
+    return isl::val::one(st->ctx_);
+  }
+  auto childrenPrefix = extendSchedule(st, prefix);
+
+  // If the tree has multiple children, take the maximum of the numbers
+  // of iterations of the children.
+  auto bound = isl::val::one(st->ctx_);
+  for (const auto& c : st->children()) {
+    auto childrenBound = boundInstances(c, childrenPrefix);
+    bound = bound.max(childrenBound);
+  }
+
+  // If the tree is a band, multiply the number of iterations of the child
+  // by the number of iterations of the band.
+  if (auto band = st->elemAs<detail::ScheduleTreeElemBand>()) {
+    if (bound.is_null()) {
+      bound = isl::val::one(st->ctx_);
+    }
+    auto partial = band->mupa_;
+    auto n = band->nMember();
+    for (int i = n - 1; i >= 0; --i) {
+      auto member = partial.get_union_pw_aff(i);
+      auto outerMap = prefix;
+      if (i > 0) {
+        auto outer = partial.drop_dims(isl::dim_type::set, i, n - i);
+        outerMap = outerMap.flat_range_product(isl::union_map::from(outer));
+      }
+      bound = bound.mul(relativeRange(outerMap, member));
+    }
+  }
+  return bound;
+}
+
+// Computes a lower bound on the maximal number of loop iterations executed
+// in the given schedule tree, inserting timeout checks along the way.
+// Returns infinity if there is a parametric band, or if timeout checks
+// were inserted in the tree.
+// Timeout checks are inserted after the innermost sequences or loops
+// containing more than timeoutCheckFrequency instances.
+// extensionSet is the union of all extension sets that need to be filtered
+// out. This is needed because it is not allowed to add an extension in a
+// tree that has an ancestor filtering out a subset of the extension.
+isl::val boundInstancesAndInsertTimeoutChecks(
+    MappedScop* mappedScop,
+    detail::ScheduleTree* st,
+    isl::union_map prefix,
+    isl::val timeoutCheckFrequency,
+    isl::union_set extensionSet) {
+  // If the tree is a leaf, there is only one iteration.
+  if (st->children().size() == 0) {
+    return isl::val::one(st->ctx_);
+  }
+
+  auto childrenPrefix = extendSchedule(st, prefix);
+
+  // If the tree is a filter, check whether it filters an ancestor extension.
+  // If it does, only bound the subtree without inserting timeout checks.
+  auto isFilter = st->elemAs<detail::ScheduleTreeElemFilter>();
+  if (isFilter && !isFilter->filter_.intersect(extensionSet).is_empty()) {
+    return boundInstances(st->child({0}), childrenPrefix);
+  }
+
+  // If the tree is an extension, unite the extension set with the
+  // previous ones.
+  if (auto extension = st->elemAs<detail::ScheduleTreeElemExtension>()) {
+    extensionSet = extension->extension_.range().unite(extensionSet);
+  }
+
+  // For every child, compute a lower bound on its instances, and take
+  // the maximum value.
+  auto bound = isl::val::one(st->ctx_);
+  for (const auto& c : st->children()) {
+    auto childrenBound = boundInstancesAndInsertTimeoutChecks(
+        mappedScop, c, childrenPrefix, timeoutCheckFrequency, extensionSet);
+    bound = bound.max(childrenBound);
+  }
+
+  // If there is already a timeout check in the children, there is no need to
+  // add more.
+  if (bound.is_infty()) {
+    return bound;
+  }
+
+  // If the tree is a sequence and a timeout check should be inserted, insert
+  // it at the end of the sequence.
+  if (bound.gt(timeoutCheckFrequency) &&
+      st->elemAs<detail::ScheduleTreeElemSequence>()) {
+    mappedScop->scop().insertTimeoutCheck(st, st->numChildren());
+    return isl::val::infty(st->ctx_);
+  }
+
+  // If the tree is a band, check at every level whether a timeout check
+  // should be inserted. Insert it if needed.
+  if (auto band = st->elemAs<detail::ScheduleTreeElemBand>()) {
+    auto partial = band->mupa_;
+    auto n = band->nMember();
+
+    for (int i = n - 1; i >= 0; --i) {
+      auto member = partial.get_union_pw_aff(i);
+      auto outerMap = prefix;
+      if (i > 0) {
+        auto outer = partial.drop_dims(isl::dim_type::set, i, n - i);
+        outerMap = outerMap.flat_range_product(isl::union_map::from(outer));
+      }
+      bound = bound.mul(relativeRange(outerMap, member));
+      if (bound.gt(timeoutCheckFrequency)) {
+        if (i > 0) {
+          bandSplit(mappedScop->scop().scheduleRoot(), st, i);
+          mappedScop->scop().insertTimeoutCheckAfter(st->child({0}));
+        } else {
+          mappedScop->scop().insertTimeoutCheckAfter(st);
+        }
+        return isl::val::infty(st->ctx_);
+      }
+    }
+  }
+  return bound;
+}
+
+} // namespace
+
+void MappedScop::insertTimeoutChecks(
+    detail::ScheduleTree* st,
+    unsigned timeoutCheckFrequency) {
+  using namespace polyhedral::detail;
+  CHECK_GT(timeoutCheckFrequency, 0);
+
+  auto timeoutCheckFrequencyVal = isl::val(st->ctx_, timeoutCheckFrequency);
+  auto root = scop().scheduleRoot();
+  auto domain = root->elemAs<ScheduleTreeElemDomain>();
+  auto prefix = prefixSchedule(root, root->child({0}));
+  boundInstancesAndInsertTimeoutChecks(
+      this,
+      st,
+      prefix,
+      timeoutCheckFrequencyVal,
+      isl::union_set::empty(domain->domain_.get_space().params()));
+}
+
 std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
     std::unique_ptr<Scop>&& scopUPtr,
     const CudaMappingOptions& cudaOptions) {
   auto mappedScop = std::unique_ptr<MappedScop>(new MappedScop(
       std::move(scopUPtr),
       ::tc::Grid(cudaOptions.grid),
       ::tc::Block(cudaOptions.block),
       generic.proto.unroll(),
-      cudaOptions.proto().use_readonly_cache()));
+      cudaOptions.proto().use_readonly_cache(),
+      cudaOptions.proto().timeout() != 0));
   auto& scop = mappedScop->scop_;
 
   // 1a. Optionally specialize before scheduling...
@@ -1083,6 +1232,11 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
       << "After outerBlockInnerThread strategy:" << std::endl
       << *mappedScop->schedule();
 
+  if (mappedScop->useTimeout) {
+    mappedScop->insertTimeoutChecks(
+        mappedScop->scop().scheduleRoot(), FLAGS_timeout_check_frequency);
+  }
+
   return mappedScop;
 }
diff --git a/tc/core/polyhedral/cuda/mapped_scop.h b/tc/core/polyhedral/cuda/mapped_scop.h
index 169b4f138..0d0ee3417 100644
--- a/tc/core/polyhedral/cuda/mapped_scop.h
+++ b/tc/core/polyhedral/cuda/mapped_scop.h
@@ -62,27 +62,35 @@ class MappedScop {
       ::tc::Grid grid,
       ::tc::Block block,
       uint64_t unroll_,
-      bool useReadOnlyCache_)
+      bool useReadOnlyCache_,
+      bool useTimeout_)
       : scop_(std::move(scop)),
         numBlocks(grid),
         numThreads(block),
         unroll(unroll_),
-        useReadOnlyCache(useReadOnlyCache_) {}
+        useReadOnlyCache(useReadOnlyCache_),
+        useTimeout(useTimeout_) {}
 
  public:
   static inline std::unique_ptr<MappedScop> makeOneBlockOneThread(
       std::unique_ptr<Scop>&& scop) {
     return std::unique_ptr<MappedScop>(new MappedScop(
-        std::move(scop), ::tc::Grid{1, 1, 1}, ::tc::Block{1, 1, 1}, 1, false));
+        std::move(scop),
+        ::tc::Grid{1, 1, 1},
+        ::tc::Block{1, 1, 1},
+        1,
+        false,
+        false));
   }
   static inline std::unique_ptr<MappedScop> makeMappedScop(
       std::unique_ptr<Scop>&& scop,
       ::tc::Grid grid,
       ::tc::Block block,
       uint64_t unroll,
-      bool useReadOnlyCache) {
-    return std::unique_ptr<MappedScop>(
-        new MappedScop(std::move(scop), grid, block, unroll, useReadOnlyCache));
+      bool useReadOnlyCache,
+      bool useTimeout) {
+    return std::unique_ptr<MappedScop>(new MappedScop(
+        std::move(scop), grid, block, unroll, useReadOnlyCache, useTimeout));
   }
 
   // Apply the hand-written OuterBlockInnerThread mapping strategy.
@@ -194,6 +202,13 @@ class MappedScop {
   // Return a pointer to the split off tile.
   detail::ScheduleTree* splitOutReductionTileAndInsertSyncs(
       detail::ScheduleTree* band);
+
+  // Insert timeout checks in the subtree rooted at "st" such that at least
+  // timeoutCheckFrequency loop iterations separate consecutive checks.
+  void insertTimeoutChecks(
+      detail::ScheduleTree* st,
+      unsigned timeoutCheckFrequency);
+
   // Map "band" to thread identifiers using as many blockSizes values as outer
   // coincident dimensions (plus reduction dimension, if any),
   // insert synchronization in case of a reduction, and
@@ -214,6 +227,7 @@ class MappedScop {
   const ::tc::Block numThreads;
   const uint64_t unroll;
   const bool useReadOnlyCache;
+  const bool useTimeout;
 
  private:
   // Information about a detected reduction that can potentially
diff --git a/tc/core/polyhedral/scop.h b/tc/core/polyhedral/scop.h
index e163dbe90..6fd048427 100644
--- a/tc/core/polyhedral/scop.h
+++ b/tc/core/polyhedral/scop.h
@@ -215,6 +215,14 @@ struct Scop {
     insertExtensionLabelAfter(scheduleRoot(), tree, makeSyncId(level));
   }
 
+  void insertTimeoutCheckAfter(detail::ScheduleTree* tree) {
+    insertExtensionLabelAfter(scheduleRoot(), tree, makeTimeoutCheckId());
+  }
+
+  void insertTimeoutCheck(detail::ScheduleTree* seqNode, size_t pos) {
+    insertExtensionLabelAt(scheduleRoot(), seqNode, pos, makeTimeoutCheckId());
+  }
+
   size_t reductionUID() const {
    static size_t count = 0;
    return count++;
  }
@@ -227,6 +235,10 @@ struct Scop {
     static size_t count = 0;
     return count++;
   }
+  size_t timeoutCheckUID() const {
+    static size_t count = 0;
+    return count++;
+  }
 
   // Make the synchronization id corresponding to the synchronization level.
   // The level should not be None.
@@ -255,6 +267,13 @@ struct Scop {
         ctx, std::string(kWarpSyncIdPrefix) + std::to_string(warpSyncUID()));
   }
 
+  isl::id makeTimeoutCheckId() const {
+    auto ctx = domain().get_ctx();
+    return isl::id(
+        ctx,
+        std::string(kTimeoutCheckPrefix) + std::to_string(timeoutCheckUID()));
+  }
+
   // Check if the id has a name with the expected prefix, followed by a long
   // integer.
   static bool isIdWithExpectedPrefix(
@@ -281,6 +300,10 @@ struct Scop {
     return isIdWithExpectedPrefix(id, kWarpSyncIdPrefix);
   }
 
+  static bool isTimeoutCheckId(isl::id id) {
+    return isIdWithExpectedPrefix(id, kTimeoutCheckPrefix);
+  }
+
   static isl::id makeRefId(isl::ctx ctx) {
     static thread_local size_t count = 0;
     return isl::id(ctx, std::string("__tc_ref_") + std::to_string(count++));
diff --git a/tc/proto/mapping_options.proto b/tc/proto/mapping_options.proto
index ff29e3557..72e36d708 100644
--- a/tc/proto/mapping_options.proto
+++ b/tc/proto/mapping_options.proto
@@ -70,6 +70,8 @@ message CudaMappingOptionsProto {
   optional uint64 max_shared_memory = 7;
   // Use the readonly cache (i.e. emit __ldg loads)
   required bool use_readonly_cache = 8;
+  // If provided, generate timeout checks. The given timeout is in ms.
+  optional uint32 timeout = 9;
 }
 
 message CpuMappingOptionsProto {
diff --git a/test/test_cuda_mapper.cc b/test/test_cuda_mapper.cc
index f33371cfb..38c40f4ed 100644
--- a/test/test_cuda_mapper.cc
+++ b/test/test_cuda_mapper.cc
@@ -91,6 +91,7 @@ struct PolyhedralMapperTest : public ::testing::Test {
         Grid{1},
         Block{blockSizes[0], blockSizes[1]},
         0,
+        false,
         false);
     auto band = mscop->mapBlocksForward(root->child({0}), 1);
     bandScale(band, tileSizes);
@@ -114,6 +115,7 @@ struct PolyhedralMapperTest : public ::testing::Test {
         Grid{gridSizes[0], gridSizes[1]},
         Block{blockSizes[0], blockSizes[1]},
         0,
+        false,
         false);
 
     // Map to blocks
@@ -364,13 +366,15 @@ def fun(float(N, M) A, float(N, M) B) -> (C) {
 
   auto res = mscop->codegen(specializedName);
 
-  std::string expected(
+  std::string expectedIds(
       R"RES(int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
-  int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
-  float32 (*C)[M] = reinterpret_cast<float32 (*)[M]>(pC);
+  int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;)RES");
+  std::string expectedDeclarations(
+      R"RES(float32 (*C)[M] = reinterpret_cast<float32 (*)[M]>(pC);
   const float32 (*A)[M] = reinterpret_cast<const float32 (*)[M]>(pA);
-  const float32 (*B)[M] = reinterpret_cast<const float32 (*)[M]>(pB);
-  for (int c1 = 16 * b1; c1 < M; c1 += 4096) {
+  const float32 (*B)[M] = reinterpret_cast<const float32 (*)[M]>(pB);)RES");
+  std::string expectedCompute(
+      R"RES(for (int c1 = 16 * b1; c1 < M; c1 += 4096) {
     if (M >= t0 + c1 + 1) {
       C[(t1 + 16 * b0)][(t0 + c1)] = (A[(t1 + 16 * b0)][(t0 + c1)] + B[(t1 + 16 * b0)][(t0 + c1)]);
     }
@@ -378,7 +382,11 @@
   }
 )RES");
 
-  ASSERT_NE(std::string::npos, std::get<0>(res).find(expected))
+  ASSERT_NE(std::string::npos, std::get<0>(res).find(expectedIds))
+      << std::get<0>(res);
+  ASSERT_NE(std::string::npos, std::get<0>(res).find(expectedDeclarations))
+      << std::get<0>(res);
+  ASSERT_NE(std::string::npos, std::get<0>(res).find(expectedCompute))
       << std::get<0>(res);
   ASSERT_EQ(32u, std::get<1>(res).view[0])
       << "Improper dim in: " << std::get<1>(res).view;
@@ -399,17 +407,21 @@ def fun(float(N, N, N, N) A, float(N, N) B, float(N, N) C, float(N, N) D)
   // Don't intersect context with the domain and see what happens
   auto res = std::get<0>(mscop->codegen(specializedName));
 
-  std::string expected(
+  std::string expectedIds(
       R"RES(int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
-  int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
-  float32 (*O1)[N] = reinterpret_cast<float32 (*)[N]>(pO1);
+  int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;)RES");
+
+  std::string expectedDeclarations(
+      R"RES(float32 (*O1)[N] = reinterpret_cast<float32 (*)[N]>(pO1);
   float32 (*O2)[N] = reinterpret_cast<float32 (*)[N]>(pO2);
   float32 (*O3)[N] = reinterpret_cast<float32 (*)[N]>(pO3);
   const float32 (*A)[N][N][N] = reinterpret_cast<const float32 (*)[N][N][N]>(pA);
   const float32 (*B)[N] = reinterpret_cast<const float32 (*)[N]>(pB);
   const float32 (*C)[N] = reinterpret_cast<const float32 (*)[N]>(pC);
-  const float32 (*D)[N] = reinterpret_cast<const float32 (*)[N]>(pD);
-  for (int c0 = 0; c0 < N; c0 += 1) {
+  const float32 (*D)[N] = reinterpret_cast<const float32 (*)[N]>(pD);)RES");
+
+  std::string expectedCompute(
+      R"RES(for (int c0 = 0; c0 < N; c0 += 1) {
     for (int c1 = 0; c1 < N; c1 += 1) {
       O1[c0][c1] = 0.000000f;
     }
@@ -436,7 +448,9 @@
   }
 )RES");
 
-  ASSERT_NE(std::string::npos, res.find(expected)) << res;
+  ASSERT_NE(std::string::npos, res.find(expectedIds)) << res;
+  ASSERT_NE(std::string::npos, res.find(expectedDeclarations)) << res;
+  ASSERT_NE(std::string::npos, res.find(expectedCompute)) << res;
 }
 
 TEST_F(PolyhedralMapperTest, BareVariables) {
@@ -450,13 +464,11 @@ def fun(float(N, N) A) -> (O)
   auto mscop = makeUnmapped(tc);
   auto res = std::get<0>(mscop->codegen(specializedName));
 
-  string expected(
-      R"RES(__global__ void kernel_anon(int32 N, float32* pO, const float32* pA) {
-  int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
-  int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
-  float32 (*O)[N] = reinterpret_cast<float32 (*)[N]>(pO);
-  const float32 (*A)[N] = reinterpret_cast<const float32 (*)[N]>(pA);
-  for (int c0 = 0; c0 < N; c0 += 1) {
+  string expected1(
+      R"RES(float32 (*O)[N] = reinterpret_cast<float32 (*)[N]>(pO);
+  const float32 (*A)[N] = reinterpret_cast<const float32 (*)[N]>(pA);)RES");
+  string expected2(
+      R"RES(for (int c0 = 0; c0 < N; c0 += 1) {
     for (int c1 = 0; c1 < N; c1 += 1) {
       O[c0][c1] = (((A[c0][c1] + float32(c0)) + float32(c1)) + float32(N));
     }
@@ -464,7 +476,8 @@
   }
 )RES");
 
-  ASSERT_NE(std::string::npos, res.find(expected)) << res;
+  ASSERT_NE(std::string::npos, res.find(expected1)) << res;
+  ASSERT_NE(std::string::npos, res.find(expected2)) << res;
 }
 
 TEST_F(PolyhedralMapperTest, CudaFunctions) {
@@ -479,15 +492,18 @@ def fun(float(N, N) A, float(N, N) B, float(N) C) -> (O)
   mscop->fixParameters({{"N", 512}});
   auto res = std::get<0>(mscop->codegen(specializedName));
 
-  string expected =
-      R"RES(__global__ void kernel_anon(int32 N, float32* pO, const float32* pA, const float32* pB, const float32* pC) {
-  int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
-  int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
-  float32 (*O)[512] = reinterpret_cast<float32 (*)[512]>(pO);
+  string expectedIds =
+      R"RES(int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
+  int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;)RES";
+
+  string expectedDeclarations =
+      R"RES(float32 (*O)[512] = reinterpret_cast<float32 (*)[512]>(pO);
   const float32 (*A)[512] = reinterpret_cast<const float32 (*)[512]>(pA);
   const float32 (*B)[512] = reinterpret_cast<const float32 (*)[512]>(pB);
-  const float32 (*C) = reinterpret_cast<const float32 (*)>(pC);
-  for (int c0 = 0; c0 <= 511; c0 += 1) {
+  const float32 (*C) = reinterpret_cast<const float32 (*)>(pC);)RES";
+
+  string expectedCompute =
+      R"RES(for (int c0 = 0; c0 <= 511; c0 += 1) {
     for (int c1 = 0; c1 <= 511; c1 += 1) {
       O[c0][c1] = (nextafter(C[c0], exp(A[c0][c1])) + log(B[c1][c0]));
     }
@@ -495,16 +511,32 @@
   }
 )RES";
 
-  ASSERT_NE(std::string::npos, res.find(expected)) << res;
+  ASSERT_NE(std::string::npos, res.find(expectedIds)) << res;
+  ASSERT_NE(std::string::npos, res.find(expectedDeclarations)) << res;
+  ASSERT_NE(std::string::npos, res.find(expectedCompute)) << res;
 }
 
-constexpr auto kExpectedMatmul_64_64_64 =
-    R"CUDA(int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
-  int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
-  float32 (*O)[64] = reinterpret_cast<float32 (*)[64]>(pO);
+TEST_F(PolyhedralMapperTest, MergedContexts) {
+  auto scop = PrepareAndJoinBandsMatMul();
+
+  // Unit test claims to use the specialized context properly
+  scop->fixParameters({{"M", 64}, {"N", 64}, {"K", 64}});
+  scop->specializeToContext();
+
+  auto mscop = TileAndMapThreads(std::move(scop), {16, 16}, {32ul, 8ul});
+  auto res = std::get<0>(mscop->codegen(specializedName));
+
+  string expectedIds =
+      R"CUDA(int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
+  int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;)CUDA";
+
+  string expectedDeclarations =
+      R"CUDA(float32 (*O)[64] = reinterpret_cast<float32 (*)[64]>(pO);
   const float32 (*A)[64] = reinterpret_cast<const float32 (*)[64]>(pA);
-  const float32 (*B)[64] = reinterpret_cast<const float32 (*)[64]>(pB);
-  for (int c0 = 0; c0 <= 63; c0 += 16) {
+  const float32 (*B)[64] = reinterpret_cast<const float32 (*)[64]>(pB);)CUDA";
+
+  string expectedCompute =
+      R"CUDA(for (int c0 = 0; c0 <= 63; c0 += 16) {
     for (int c1 = 0; c1 <= 63; c1 += 16) {
       for (int c2 = t1; c2 <= 15; c2 += 8) {
         O[(c0 + c2)][(t0 + c1)] = 0.000000f;
@@ -517,16 +549,9 @@
   }
 )CUDA";
 
-TEST_F(PolyhedralMapperTest, MergedContexts) {
-  auto scop = PrepareAndJoinBandsMatMul();
-
-  // Unit test claims to use the specialized context properly
-  scop->fixParameters({{"M", 64}, {"N", 64}, {"K", 64}});
-  scop->specializeToContext();
-
-  auto mscop = TileAndMapThreads(std::move(scop), {16, 16}, {32ul, 8ul});
-  auto res = std::get<0>(mscop->codegen(specializedName));
-  ASSERT_TRUE(std::string::npos != res.find(kExpectedMatmul_64_64_64)) << res;
+  ASSERT_TRUE(std::string::npos != res.find(expectedIds)) << res;
+  ASSERT_TRUE(std::string::npos != res.find(expectedDeclarations)) << res;
+  ASSERT_TRUE(std::string::npos != res.find(expectedCompute)) << res;
 }
 
 TEST_F(PolyhedralMapperTest, Match1) {
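
[Note: a condensed, hand-written sketch of the timeout protocol that the
generated kernels follow -- illustrative only. The __timestamp() definition
matches code::c::timestampFunction in tc/core/libraries.h; the kernel name,
its out/n parameters, and its loop body are hypothetical stand-ins for
generated code.]

    __device__ unsigned long long __timestamp() {
      unsigned long long startns;
      asm volatile("mov.u64 %0,%%globaltimer;" : "=l"(startns));
      return startns;
    }

    __global__ void timeoutSketch(
        unsigned long long* startTime, // zeroed by the host before launch
        unsigned long long timeout,    // in ns, converted from ms by Launch()
        float* out,
        int n) {
      // Block preamble (emitTimestampInit): publish the first timestamp and
      // bail out if an earlier block started more than "timeout" ns ago.
      unsigned long long startns = __timestamp();
      unsigned long long old_startns = atomicCAS(startTime, 0ULL, startns);
      if (old_startns != 0 && old_startns < startns &&
          startns - old_startns > timeout) {
        return;
      }
      for (int i = threadIdx.x; i < n; i += blockDim.x) {
        out[i] *= 2.0f;
        // The _timeoutCheck_ statement: the real generated code emits this at
        // most once per timeout_check_frequency loop iterations; it is shown
        // on every iteration here for brevity.
        if (__timestamp() - startns > timeout) {
          return;
        }
      }
    }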