diff --git a/tc/core/compiler-inl.h b/tc/core/compiler-inl.h
index e72ddfadb..616b6f993 100644
--- a/tc/core/compiler-inl.h
+++ b/tc/core/compiler-inl.h
@@ -66,6 +66,10 @@ std::unique_ptr<typename Backend::ExecutorType> compile(
       options);
   return std::unique_ptr<typename Backend::ExecutorType>(
       new typename Backend::ExecutorType(
-          inputsInfo, outputsInfo, halideComponents, compilationResult));
+          inputsInfo,
+          outputsInfo,
+          halideComponents,
+          compilationResult,
+          options));
 }
 } // namespace tc
diff --git a/tc/core/constants.h b/tc/core/constants.h
index 8ec277583..503878a4d 100644
--- a/tc/core/constants.h
+++ b/tc/core/constants.h
@@ -30,5 +30,7 @@ constexpr auto kWriteIdName = "write";
 constexpr auto kSyncIdPrefix = "_sync_";
 constexpr auto kWarpSyncIdPrefix = "_warpSync_";
 
+constexpr auto kTimeoutCheckPrefix = "_timeoutCheck_";
+
 } // namespace polyhedral
 } // namespace tc
diff --git a/tc/core/cpu/cpu_tc_executor.cc b/tc/core/cpu/cpu_tc_executor.cc
index a76efc714..924ecca0b 100644
--- a/tc/core/cpu/cpu_tc_executor.cc
+++ b/tc/core/cpu/cpu_tc_executor.cc
@@ -29,7 +29,8 @@ CpuTcExecutor::CpuTcExecutor(
     const std::vector<TensorInfo>& inputsInfo,
     const std::vector<TensorInfo>& outputsInfo,
     const tc2halide::HalideComponents& halideComponents,
-    const typename CpuBackend::CompilationResultType& compilationResult)
+    const typename CpuBackend::CompilationResultType& compilationResult,
+    const typename CpuBackend::MappingOptionsType& options)
     : TcExecutor<CpuBackend>(
           inputsInfo,
           outputsInfo,
diff --git a/tc/core/cpu/cpu_tc_executor.h b/tc/core/cpu/cpu_tc_executor.h
index 079ae1589..e335caaf6 100644
--- a/tc/core/cpu/cpu_tc_executor.h
+++ b/tc/core/cpu/cpu_tc_executor.h
@@ -30,7 +30,8 @@ class CpuTcExecutor : public TcExecutor<CpuBackend> {
       const std::vector<TensorInfo>& inputsInfo,
       const std::vector<TensorInfo>& outputsInfo,
       const tc2halide::HalideComponents& halideComponents,
-      const typename CpuBackend::CompilationResultType& compilationResult);
+      const typename CpuBackend::CompilationResultType& compilationResult,
+      const typename CpuBackend::MappingOptionsType& options);
 
   /// This is the "low-latency" mode in which we just propagate raw pointers to
   /// data in the address space where kernel is executed.
diff --git a/tc/core/cuda/cuda_mapping_options.cc b/tc/core/cuda/cuda_mapping_options.cc
index 826939488..a1b190b45 100644
--- a/tc/core/cuda/cuda_mapping_options.cc
+++ b/tc/core/cuda/cuda_mapping_options.cc
@@ -287,6 +287,11 @@ CudaMappingOptions& CudaMappingOptions::useReadOnlyCache(bool b) {
   return *this;
 }
 
+CudaMappingOptions& CudaMappingOptions::timeout(uint32_t ms) {
+  ownedProto_.set_timeout(ms);
+  return *this;
+}
+
 CudaMappingOptions& CudaMappingOptions::mapToThreads(
     const std::string& commaSeparatedSizes) {
   auto sizes = parseCommaSeparatedIntegers(commaSeparatedSizes);
@@ -318,7 +323,8 @@ CudaMappingOptions CudaMappingOptions::makeUnmappedMappingOptions() {
       .useSharedMemory(false)
       .usePrivateMemory(false)
       .unrollCopyShared(false)
-      .useReadOnlyCache(false);
+      .useReadOnlyCache(false)
+      .timeout(FLAGS_timeout);
   return mo;
 }
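For downstream users, the new option composes with the existing fluent setters; a minimal usage sketch (the surrounding option values are hypothetical, only `timeout(...)` is the new API):

```cpp
// Request timeout checks: the kernel aborts roughly `timeout` ms after the
// first block starts. 0 (the FLAGS_timeout default) keeps the generated
// code free of checks.
auto options = tc::CudaMappingOptions::makeNaiveMappingOptions()
                   .useReadOnlyCache(true)
                   .timeout(100); // ms
```

Note that `makeUnmappedMappingOptions()` seeds the field from `FLAGS_timeout`, so the command-line flag acts as the process-wide default.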
diff --git a/tc/core/cuda/cuda_mapping_options.h b/tc/core/cuda/cuda_mapping_options.h
index 8afb6e2b0..15cb0b459 100644
--- a/tc/core/cuda/cuda_mapping_options.h
+++ b/tc/core/cuda/cuda_mapping_options.h
@@ -197,6 +197,11 @@ class CudaMappingOptions {
   CudaMappingOptions& useReadOnlyCache(bool b);
   ///@}
 
+  /// Set the kernel timeout, in ms (0 disables timeout checks)
+  ///@{
+  CudaMappingOptions& timeout(uint32_t ms);
+  ///@}
+
   /// Static constructors for predefined strategies.
   ///@{
   static CudaMappingOptions makeNaiveMappingOptions();
diff --git a/tc/core/cuda/cuda_mapping_options_cpp_printer.cc b/tc/core/cuda/cuda_mapping_options_cpp_printer.cc
index 9e46367f6..a2ff884c5 100644
--- a/tc/core/cuda/cuda_mapping_options_cpp_printer.cc
+++ b/tc/core/cuda/cuda_mapping_options_cpp_printer.cc
@@ -38,6 +38,9 @@ CudaMappingOptionsCppPrinter& operator<<(
     prn.printValueOption(
         "maxSharedMemory", cudaOptions.proto().max_shared_memory());
   }
+  if (cudaOptions.proto().has_timeout()) {
+    prn.printValueOption("timeout", cudaOptions.proto().timeout());
+  }
   prn.endStmt();
   return prn;
 }
diff --git a/tc/core/cuda/cuda_rtc.cc b/tc/core/cuda/cuda_rtc.cc
index b25c968a9..abae4d716 100644
--- a/tc/core/cuda/cuda_rtc.cc
+++ b/tc/core/cuda/cuda_rtc.cc
@@ -28,7 +28,10 @@ namespace tc {
 std::mutex nvrtc_mutex;
 
-CudaRTCFunction::CudaRTCFunction() {}
+CudaRTCFunction::CudaRTCFunction() {
+  TC_CUDA_RUNTIMEAPI_ENFORCE(
+      cudaMalloc((void**)&startTimeDev, sizeof(unsigned long long)));
+}
 
 CudaRTCFunction::~CudaRTCFunction() {
   if (!cleared_) {
@@ -43,6 +46,7 @@ void CudaRTCFunction::clear() {
       WithCudaDevice(kvp.first);
       TC_CUDA_DRIVERAPI_ENFORCE(cuModuleUnload(kvp.second));
     }
+    TC_CUDA_RUNTIMEAPI_ENFORCE(cudaFree((void*)startTimeDev));
     cleared_ = true;
   }
 }
@@ -136,7 +140,9 @@ Duration CudaRTCFunction::Launch(
     std::vector<long> params,
     std::vector<void*> outputs,
     std::vector<const void*> inputs,
+    uint32_t timeout,
     bool profile) const {
+  uint64_t timeoutInNs = static_cast<uint64_t>(timeout) * 1000 * 1000;
   int dev;
   TC_CUDA_RUNTIMEAPI_ENFORCE(cudaGetDevice(&dev));
   if (perGpuModule_.count(dev) == 0) {
@@ -152,11 +158,19 @@ Duration CudaRTCFunction::Launch(
   constexpr size_t kNumMaxParameters = 100;
   std::array<void*, kNumMaxParameters> args_voidp{0};
-  CHECK_GE(kNumMaxParameters, params.size() + outputs.size() + inputs.size());
+  CHECK_GE(
+      kNumMaxParameters,
+      params.size() + outputs.size() + inputs.size() + 2 * (timeout != 0));
   int ind = 0;
   for (auto& p : params) {
     args_voidp[ind++] = &p;
   }
+  if (timeout != 0) {
+    args_voidp[ind++] =
+        const_cast<void*>(static_cast<const void*>(&startTimeDev));
+    args_voidp[ind++] =
+        const_cast<void*>(static_cast<const void*>(&timeoutInNs));
+  }
   for (auto& o : outputs) {
     args_voidp[ind++] = &o;
   }
@@ -171,6 +185,15 @@ Duration CudaRTCFunction::Launch(
   unsigned int bx = block[0];
   unsigned int by = block[1];
   unsigned int bz = block[2];
+  if (timeout != 0) {
+    // Reset the shared start-time slot so the first block of this launch
+    // publishes its own timestamp.
+    unsigned long long startTime = 0;
+    TC_CUDA_RUNTIMEAPI_ENFORCE(cudaMemcpy(
+        (void*)startTimeDev,
+        (void*)&startTime,
+        sizeof(unsigned long long),
+        cudaMemcpyHostToDevice));
+  }
+
   auto launch = [&]() {
     TC_CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel(
         perGpuKernel_.at(dev),
diff --git a/tc/core/cuda/cuda_rtc.h b/tc/core/cuda/cuda_rtc.h
index 4aa5e831f..cfd7ba196 100644
--- a/tc/core/cuda/cuda_rtc.h
+++ b/tc/core/cuda/cuda_rtc.h
@@ -57,6 +57,7 @@ class CudaRTCFunction {
       std::vector<long> params,
       std::vector<void*> outputs,
       std::vector<const void*> inputs,
+      uint32_t timeout,
       bool profile = false) const;
 
   void clear();
@@ -64,6 +65,7 @@ class CudaRTCFunction {
  private:
   mutable std::unordered_map<size_t, CUmodule> perGpuModule_;
   mutable std::unordered_map<size_t, CUfunction> perGpuKernel_;
+  unsigned long long* startTimeDev;
   std::string specializedName;
   std::vector<char> nvrtc_ptx;
   bool cleared_;
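One subtlety worth calling out: `cuLaunchKernel` takes an array of addresses of the argument values, so `Launch` stores the address of the device pointer (`&startTimeDev`) and the address of the 64-bit nanosecond budget, not the values themselves. A minimal standalone illustration with hypothetical names:

```cpp
// For a kernel k(unsigned long long* startTime, unsigned long long timeout),
// each void* slot points at the storage holding one argument.
unsigned long long* startTimeDev = nullptr; // assumed cudaMalloc'ed elsewhere
unsigned long long timeoutInNs = 100ull * 1000 * 1000; // 100 ms
void* args[] = {&startTimeDev, &timeoutInNs};
// cuLaunchKernel(func, gx, gy, gz, bx, by, bz, shared, stream, args, nullptr);
```

The two extra slots sit between the scalar parameters and the tensor pointers, mirroring the parameter order that `emitParams`/`emitArgs` generate in the kernel signature (see codegen.cc below).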
diff --git a/tc/core/cuda/cuda_tc_executor.cc b/tc/core/cuda/cuda_tc_executor.cc
index 72a1350ad..819b0c5aa 100644
--- a/tc/core/cuda/cuda_tc_executor.cc
+++ b/tc/core/cuda/cuda_tc_executor.cc
@@ -46,12 +46,14 @@ CudaTcExecutor::CudaTcExecutor(
     const std::vector<TensorInfo>& inputsInfo,
     const std::vector<TensorInfo>& outputsInfo,
     const tc2halide::HalideComponents& halideComponents,
-    const typename CudaBackend::CompilationResultType& compilationResult)
+    const typename CudaBackend::CompilationResultType& compilationResult,
+    const typename CudaBackend::MappingOptionsType& options)
     : TcExecutor<CudaBackend>(
           inputsInfo,
           outputsInfo,
           halideComponents,
-          compilationResult) {
+          compilationResult),
+      timeout_(options.proto().timeout()) {
   auto t0 = std::chrono::high_resolution_clock::now();
   // force unloading in case we JIT with the same name/input/outputs with
   // different options.
@@ -121,7 +123,8 @@ void CudaTcExecutor::uncheckedRun(
       info.stream,
       parameters_,
       outputs,
-      inputs);
+      inputs,
+      timeout_);
 }
 
 ProfilingInfo CudaTcExecutor::profileUnchecked(
@@ -140,6 +143,7 @@ ProfilingInfo CudaTcExecutor::profileUnchecked(
       parameters_,
       outputs,
       inputs,
+      timeout_,
       true));
   // The CPU overhead is the total time minus the (synchronized) kernel runtime
   Duration cpuOverhead(Duration::since(start));
diff --git a/tc/core/cuda/cuda_tc_executor.h b/tc/core/cuda/cuda_tc_executor.h
index a497c940d..596ec8e70 100644
--- a/tc/core/cuda/cuda_tc_executor.h
+++ b/tc/core/cuda/cuda_tc_executor.h
@@ -30,7 +30,8 @@ class CudaTcExecutor : public TcExecutor<CudaBackend> {
       const std::vector<TensorInfo>& inputsInfo,
       const std::vector<TensorInfo>& outputsInfo,
       const tc2halide::HalideComponents& halideComponents,
-      const typename CudaBackend::CompilationResultType& compilationResult);
+      const typename CudaBackend::CompilationResultType& compilationResult,
+      const typename CudaBackend::MappingOptionsType& options);
 
   /// This is the "low-latency" mode in which we just propagate raw pointers to
   /// data in the address space where kernel is executed.
@@ -63,5 +64,7 @@ class CudaTcExecutor : public TcExecutor<CudaBackend> {
   // GPU-specific results of compilation
   Grid grid_;
   Block block_;
+
+  // Kernel timeout in ms (0 means no timeout checks were generated)
+  uint32_t timeout_;
 };
 } // namespace tc
diff --git a/tc/core/flags.cc b/tc/core/flags.cc
index 80c9e5ec5..c79b99d6c 100644
--- a/tc/core/flags.cc
+++ b/tc/core/flags.cc
@@ -38,6 +38,12 @@ DEFINE_bool(
 DEFINE_bool(dump_cuda, false, "Print the generated source");
 DEFINE_bool(dump_ptx, false, "Dump the generated PTX");
 
+DEFINE_uint32(
+    timeout_check_frequency,
+    100,
+    "The minimum number of loop iterations between two timeout checks");
+DEFINE_uint32(timeout, 0, "The CUDA kernel timeout in ms (0 disables it)");
+
 // CPU codegen options
 DEFINE_bool(llvm_dump_before_opt, false, "Print IR before optimization");
 DEFINE_bool(llvm_dump_after_opt, false, "Print IR after optimization");
diff --git a/tc/core/flags.h b/tc/core/flags.h
index c748759a4..d69dd1294 100644
--- a/tc/core/flags.h
+++ b/tc/core/flags.h
@@ -31,6 +31,10 @@ DECLARE_bool(debug_tuner);
 DECLARE_bool(dump_cuda);
 DECLARE_bool(dump_ptx);
 
+// CUDA timeout
+DECLARE_uint32(timeout_check_frequency);
+DECLARE_uint32(timeout);
+
 // llvm codegen
 DECLARE_bool(llvm_dump_before_opt);
 DECLARE_bool(llvm_dump_after_opt);
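Since these are ordinary gflags, they can also be overridden programmatically before compilation; a small hypothetical sketch (the function name is made up):

```cpp
#include <gflags/gflags.h>

// Declarations matching tc/core/flags.h.
DECLARE_uint32(timeout);
DECLARE_uint32(timeout_check_frequency);

// Abort kernels after 50 ms, checking at most once every 1000 iterations.
void configureTimeouts() {
  FLAGS_timeout = 50;
  FLAGS_timeout_check_frequency = 1000;
}
```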
diff --git a/tc/core/libraries.h b/tc/core/libraries.h
index 7605cf7f7..68e523f43 100644
--- a/tc/core/libraries.h
+++ b/tc/core/libraries.h
@@ -43,6 +43,13 @@ constexpr auto defines = R"C(
 #define inf __longlong_as_double(0x7ff0000000000000LL)
 )C";
 
+constexpr auto timestampFunction = R"C(
+__device__ unsigned long long __timestamp() {
+  unsigned long long startns;
+  asm volatile("mov.u64 %0,%%globaltimer;" : "=l"(startns));
+  return startns;
+})C";
+
 constexpr auto warpSyncFunctions = R"C(
 // Before CUDA 9, syncwarp is a noop since warps are always synchronized.
 #if __CUDACC_VER_MAJOR__ < 9
diff --git a/tc/core/polyhedral/cuda/codegen.cc b/tc/core/polyhedral/cuda/codegen.cc
index 9ceb740b5..34642b444 100644
--- a/tc/core/polyhedral/cuda/codegen.cc
+++ b/tc/core/polyhedral/cuda/codegen.cc
@@ -95,8 +95,9 @@ struct AstPrinter {
   bool inReduction_ = false;
 };
 
-vector<string> emitParams(const Scop& scop) {
+vector<string> emitParams(const MappedScop& mappedScop) {
   vector<string> res;
+  const auto& scop = mappedScop.scop();
   res.reserve(scop.halide.params.size());
   // Halide params. One of these two vectors will be empty.
   for (auto p : scop.halide.params) {
@@ -104,6 +105,10 @@ vector<string> emitParams(const Scop& scop) {
     stringstream ss;
     ss << p.type() << " " << p.name();
     res.push_back(ss.str());
   }
+  if (mappedScop.useTimeout) {
+    res.push_back("unsigned long long* startTime");
+    res.push_back("unsigned long long timeout");
+  }
   return res;
 }
@@ -136,9 +145,10 @@ vector<string> emitTypedTensorNames(const vector<Halide::ImageParam>& tensors) {
   return res;
 }
 
-void emitArgs(stringstream& ss, const Scop& scop) {
+void emitArgs(stringstream& ss, const MappedScop& mappedScop) {
   // Order is: params, outs, ins
-  auto sigVec = emitParams(scop);
+  const auto& scop = mappedScop.scop();
+  auto sigVec = emitParams(mappedScop);
   sigVec = sigVec + emitTypedTensorNames(scop.halide.outputs);
   sigVec = sigVec + emitTypedTensorNames(scop.halide.inputs);
   for (auto& s : sigVec) {
@@ -152,10 +162,10 @@ vector<string> emitTypedTensorNames(const vector<Halide::ImageParam>& tensors) {
 void emitKernelSignature(
     stringstream& ss,
     const std::string& specializedName,
-    const Scop& scop) {
+    const MappedScop& mappedScop) {
   CHECK_NE(specializedName, "") << "name not provided";
   ss << "__global__ void " << specializedName << "(";
-  emitArgs(ss, scop);
+  emitArgs(ss, mappedScop);
   ss << ") {" << endl;
 }
@@ -452,6 +462,10 @@ void AstPrinter::emitStmt(isl::ast_node_user node) {
   } else if (
       stmtId.get_name() == kReadIdName || stmtId.get_name() == kWriteIdName) {
     emitCopyStmt(statementContext);
+  } else if (context_.scop().isTimeoutCheckId(stmtId)) {
+    context_.ss << "if (__timestamp() - startns > timeout) {\n";
+    context_.ss << ws.tab() << ws.tab() << "return;\n";
+    context_.ss << ws.tab() << "}" << std::endl;
   } else { // regular statement
     auto mappedStmtId = statementContext.statementId();
     CHECK_EQ(stmtId, mappedStmtId)
@@ -668,6 +682,21 @@ void emitThreadIdInit(stringstream& ss, const MappedScop& scop) {
   ss << "int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;\n";
 }
 
+void emitTimestampInit(stringstream& ss) {
+  WS ws;
+  ss << ws.tab();
+  ss << "unsigned long long startns = __timestamp();\n";
+  ss << ws.tab();
+  ss << "unsigned long long old_startns = atomicCAS(startTime, 0, startns);\n";
+  ss << ws.tab();
+  ss << "if (old_startns != 0 && old_startns < startns && "
+        "startns - old_startns > timeout) {\n";
+  ss << ws.tab() << ws.tab();
+  ss << "return;\n";
+  ss << ws.tab();
+  ss << "}\n";
+}
+
 void emitTmpDecl(stringstream& ss, const Scop& scop) {
   for (const auto& kvp : scop.treeSyncUpdateMap) {
     WS ws;
@@ -752,12 +782,15 @@ string emitCudaKernel(
   }
 
   stringstream ss;
-  emitKernelSignature(ss, specializedName, scop);
+  emitKernelSignature(ss, specializedName, mscop);
   emitThreadIdInit(ss, mscop);
   emitTensorViews(ss, scop.halide.outputs, paramValues);
   emitTensorViews(ss, scop.halide.inputs, paramValues);
   emitTmpDecl(ss, scop);
   emitPromotedArrayViewsHalide(ss, scop);
+  if (mscop.useTimeout) {
+    emitTimestampInit(ss);
+  }
   NodeInfoMapType nodeInfoMap;
   auto collect = [&nodeInfoMap](
                      isl::ast_node n, isl::ast_build b) -> isl::ast_node {
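Putting the pieces together, a generated kernel with timeouts enabled would look roughly like this hand-written sketch (assembled from `timestampFunction`, `emitTimestampInit`, and the per-check emission above; the tensor names and loop structure are invented):

```cuda
__device__ unsigned long long __timestamp() {
  unsigned long long startns;
  asm volatile("mov.u64 %0,%%globaltimer;" : "=l"(startns));
  return startns;
}

__global__ void kernel_example(
    int N,
    unsigned long long* startTime, // zeroed by the host before launch
    unsigned long long timeout,    // budget in ns
    float* pO,
    const float* pI) {
  // Preamble from emitTimestampInit: the first block to run publishes its
  // start time; blocks scheduled after the budget is spent bail out at once.
  unsigned long long startns = __timestamp();
  unsigned long long old_startns = atomicCAS(startTime, 0, startns);
  if (old_startns != 0 && old_startns < startns &&
      startns - old_startns > timeout) {
    return;
  }
  for (int c0 = 0; c0 < N; c0 += 1) {
    pO[c0] = pI[c0];
    // Check emitted for a _timeoutCheck_ statement in the schedule tree.
    if (__timestamp() - startns > timeout) {
      return;
    }
  }
}
```

Blocks that return early leave the outputs partially written, so the result of a timed-out launch has to be discarded; the mechanism is only meant to bound the runtime of bad candidate kernels, e.g. during autotuning.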
diff --git a/tc/core/polyhedral/cuda/mapped_scop.cc b/tc/core/polyhedral/cuda/mapped_scop.cc
index 0304d269a..4f34e9d0b 100644
--- a/tc/core/polyhedral/cuda/mapped_scop.cc
+++ b/tc/core/polyhedral/cuda/mapped_scop.cc
@@ -894,14 +894,14 @@ std::unique_ptr<MappedScop> makeSpecializedMappedScop(
       grid,
       block,
       mappedScop.unroll,
-      mappedScop.useReadOnlyCache);
+      mappedScop.useReadOnlyCache,
+      mappedScop.useTimeout);
   res->insertMappingContext();
   LOG_IF(INFO, FLAGS_debug_tc_mapper)
       << "Codegen with tightened bounds [blocks:" << grid
       << ", threads:" << block << "] for tree:\n"
       << *res->schedule();
-
   return res;
 }
 } // namespace
@@ -918,6 +918,9 @@ std::tuple<std::string, tc::Grid, tc::Block> MappedScop::codegen(
   std::stringstream code;
   code << code::cpp::boundsAsTemplate << code::c::types << code::c::defines;
   code << code::c::warpSyncFunctions;
+  if (useTimeout) {
+    code << code::c::timestampFunction;
+  }
   code << std::endl;
   if (mappedScopForCodegen->scop().treeSyncUpdateMap.size() != 0) {
     code << code::cuda::common;
@@ -961,6 +964,148 @@ detail::ScheduleTree* MappedScop::splitOutReductionTileAndInsertSyncs(
   return child;
 }
 
+namespace {
+
+// Computes a lower bound of the maximum number of loop iterations in the given
+// schedule tree. Returns infinity if there is a parametric band.
+isl::val boundInstances(detail::ScheduleTree* st, isl::union_map prefix) {
+  // If the tree is a leaf, there is only one iteration.
+  if (st->children().size() == 0) {
+    return isl::val::one(st->ctx_);
+  }
+  auto childrenPrefix = extendSchedule(st, prefix);
+
+  // If the tree has multiple children, take the maximum of the numbers of
+  // iterations of the children.
+  auto bound = isl::val::one(st->ctx_);
+  for (const auto& c : st->children()) {
+    auto childrenBound = boundInstances(c, childrenPrefix);
+    bound = bound.max(childrenBound);
+  }
+
+  // If the tree is a band, multiply the bound of the children by the number
+  // of iterations of each band member.
+  if (auto band = st->elemAs<detail::ScheduleTreeElemBand>()) {
+    auto partial = band->mupa_;
+    auto n = band->nMember();
+    for (int i = n - 1; i >= 0; --i) {
+      auto member = partial.get_union_pw_aff(i);
+      auto outerMap = prefix;
+      if (i > 0) {
+        auto outer = partial.drop_dims(isl::dim_type::set, i, n - i);
+        outerMap = outerMap.flat_range_product(isl::union_map::from(outer));
+      }
+      bound = bound.mul(relativeRange(outerMap, member));
+    }
+  }
+  return bound;
+}
+
+// Computes a lower bound of the maximum number of loop iterations in the
+// given schedule tree and inserts timeout checks. Returns infinity if there
+// is a parametric band or if timeout checks were inserted in the tree.
+// Timeout checks are inserted after the innermost sequences or loops
+// containing more than timeoutCheckFrequency instances.
+// extensionSet is the union of all extension sets that need to be filtered
+// out. This is needed because it is not allowed to add an extension in a
+// tree which has an ancestor filtering a subset of that extension.
+isl::val boundInstancesAndInsertTimeoutChecks(
+    MappedScop* mappedScop,
+    detail::ScheduleTree* st,
+    isl::union_map prefix,
+    isl::val timeoutCheckFrequency,
+    isl::union_set extensionSet) {
+  // If the tree is a leaf, there is only one iteration.
+  if (st->children().size() == 0) {
+    return isl::val::one(st->ctx_);
+  }
+
+  auto childrenPrefix = extendSchedule(st, prefix);
+
+  // If the tree is a filter that filters an ancestor extension, don't insert
+  // timeout checks in the subtree; only compute its bound.
+  auto filter = st->elemAs<detail::ScheduleTreeElemFilter>();
+  if (filter && !filter->filter_.intersect(extensionSet).is_empty()) {
+    return boundInstances(st->child({0}), childrenPrefix);
+  }
+
+  // If the tree is an extension, unite the extension set with the
+  // previous ones.
+  if (auto extension = st->elemAs<detail::ScheduleTreeElemExtension>()) {
+    extensionSet = extension->extension_.range().unite(extensionSet);
+  }
+
+  // For every child, compute a lower bound on its number of instances and
+  // take the maximum over all children.
+  auto bound = isl::val::one(st->ctx_);
+  for (const auto& c : st->children()) {
+    auto childrenBound = boundInstancesAndInsertTimeoutChecks(
+        mappedScop, c, childrenPrefix, timeoutCheckFrequency, extensionSet);
+    bound = bound.max(childrenBound);
+  }
+
+  // If there is already a timeout check in the children, there is no need to
+  // add more.
+  if (bound.is_infty()) {
+    return bound;
+  }
+
+  // If the tree is a sequence and a timeout check should be inserted, insert
+  // it at the end of the sequence.
+  if (bound.gt(timeoutCheckFrequency) &&
+      st->elemAs<detail::ScheduleTreeElemSequence>()) {
+    mappedScop->scop().insertTimeoutCheck(st, st->numChildren());
+    return isl::val::infty(st->ctx_);
+  }
+
+  // If the tree is a band, check at every level whether a timeout check
+  // should be inserted, and insert it if needed.
+  if (auto band = st->elemAs<detail::ScheduleTreeElemBand>()) {
+    auto partial = band->mupa_;
+    auto n = band->nMember();
+
+    for (int i = n - 1; i >= 0; --i) {
+      auto member = partial.get_union_pw_aff(i);
+      auto outerMap = prefix;
+      if (i > 0) {
+        auto outer = partial.drop_dims(isl::dim_type::set, i, n - i);
+        outerMap = outerMap.flat_range_product(isl::union_map::from(outer));
+      }
+      bound = bound.mul(relativeRange(outerMap, member));
+      if (bound.gt(timeoutCheckFrequency)) {
+        if (i > 0) {
+          bandSplit(mappedScop->scop().scheduleRoot(), st, i);
+          mappedScop->scop().insertTimeoutCheckAfter(st->child({0}));
+        } else {
+          mappedScop->scop().insertTimeoutCheckAfter(st);
+        }
+        return isl::val::infty(st->ctx_);
+      }
+    }
+  }
+  return bound;
+}
+
+} // namespace
+
+void MappedScop::insertTimeoutChecks(
+    detail::ScheduleTree* st,
+    unsigned timeoutCheckFrequency) {
+  using namespace polyhedral::detail;
+  CHECK_GT(timeoutCheckFrequency, 0);
+
+  auto timeoutCheckFrequencyVal = isl::val(st->ctx_, timeoutCheckFrequency);
+  auto root = scop().scheduleRoot();
+  auto domain = root->elemAs<ScheduleTreeElemDomain>();
+  auto prefix = prefixSchedule(root, root->child({0}));
+  boundInstancesAndInsertTimeoutChecks(
+      this,
+      st,
+      prefix,
+      timeoutCheckFrequencyVal,
+      isl::union_set::empty(domain->domain_.get_space().params()));
+}
+
 std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
     std::unique_ptr<Scop>&& scopUPtr,
     const CudaMappingOptions& cudaOptions) {
@@ -972,7 +1120,8 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
       ::tc::Grid(cudaOptions.grid),
       ::tc::Block(cudaOptions.block),
       generic.proto.unroll(),
-      cudaOptions.proto().use_readonly_cache()));
+      cudaOptions.proto().use_readonly_cache(),
+      cudaOptions.proto().timeout() != 0));
   auto& scop = mappedScop->scop_;
 
   // 1a. Optionally specialize before scheduling...
@@ -1083,6 +1232,11 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
       << "After outerBlockInnerThread strategy:" << std::endl
       << *mappedScop->schedule();
 
+  if (mappedScop->useTimeout) {
+    mappedScop->insertTimeoutChecks(
+        mappedScop->scop().scheduleRoot(), FLAGS_timeout_check_frequency);
+  }
+
   return mappedScop;
 }
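To make the band placement rule concrete, here is a hypothetical walk-through (the numbers are invented): with a check frequency of 20 and a band whose members have trip counts 8 (outer) and 32 (inner), the scan from the innermost member finds the running bound 32 > 20 at i = 1, so the band is split at position 1 and the check lands after the inner band, i.e. inside the outer loop:

```cpp
// After bandSplit(..., 1) + insertTimeoutCheckAfter(st->child({0})),
// the generated loop nest would look like:
for (int c0 = 0; c0 < 8; c0 += 1) {
  for (int c1 = 0; c1 < 32; c1 += 1) {
    body(c0, c1);
  }
  // _timeoutCheck_ statement: runs once every 32 (>= 20) body iterations.
  if (__timestamp() - startns > timeout) {
    return;
  }
}
```

If the bound only exceeds the frequency at the outermost member (i = 0), the check is instead inserted after the whole band, where it is reached once per instance of any enclosing loop or sequence.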
diff --git a/tc/core/polyhedral/cuda/mapped_scop.h b/tc/core/polyhedral/cuda/mapped_scop.h
index 169b4f138..0d0ee3417 100644
--- a/tc/core/polyhedral/cuda/mapped_scop.h
+++ b/tc/core/polyhedral/cuda/mapped_scop.h
@@ -62,27 +62,35 @@ class MappedScop {
       ::tc::Grid grid,
       ::tc::Block block,
       uint64_t unroll_,
-      bool useReadOnlyCache_)
+      bool useReadOnlyCache_,
+      bool useTimeout_)
       : scop_(std::move(scop)),
         numBlocks(grid),
         numThreads(block),
         unroll(unroll_),
-        useReadOnlyCache(useReadOnlyCache_) {}
+        useReadOnlyCache(useReadOnlyCache_),
+        useTimeout(useTimeout_) {}
 
  public:
   static inline std::unique_ptr<MappedScop> makeOneBlockOneThread(
       std::unique_ptr<Scop>&& scop) {
     return std::unique_ptr<MappedScop>(new MappedScop(
-        std::move(scop), ::tc::Grid{1, 1, 1}, ::tc::Block{1, 1, 1}, 1, false));
+        std::move(scop),
+        ::tc::Grid{1, 1, 1},
+        ::tc::Block{1, 1, 1},
+        1,
+        false,
+        false));
   }
   static inline std::unique_ptr<MappedScop> makeMappedScop(
       std::unique_ptr<Scop>&& scop,
       ::tc::Grid grid,
       ::tc::Block block,
       uint64_t unroll,
-      bool useReadOnlyCache) {
-    return std::unique_ptr<MappedScop>(
-        new MappedScop(std::move(scop), grid, block, unroll, useReadOnlyCache));
+      bool useReadOnlyCache,
+      bool useTimeout) {
+    return std::unique_ptr<MappedScop>(new MappedScop(
+        std::move(scop), grid, block, unroll, useReadOnlyCache, useTimeout));
   }
 
   // Apply the hand-written OuterBlockInnerThread mapping strategy.
@@ -194,6 +202,11 @@ class MappedScop {
   // Return a pointer to the split off tile.
   detail::ScheduleTree* splitOutReductionTileAndInsertSyncs(
       detail::ScheduleTree* band);
+
+  // Insert a timeout check roughly every timeoutCheckFrequency loop
+  // iterations in the subtree rooted at "st".
+  void insertTimeoutChecks(
+      detail::ScheduleTree* st,
+      unsigned timeoutCheckFrequency);
+
   // Map "band" to thread identifiers using as many blockSizes values as outer
   // coincident dimensions (plus reduction dimension, if any),
   // insert synchronization in case of a reduction, and
@@ -214,6 +227,7 @@ class MappedScop {
   const ::tc::Block numThreads;
   const uint64_t unroll;
   const bool useReadOnlyCache;
+  const bool useTimeout;
 
  private:
   // Information about a detected reduction that can potentially
diff --git a/tc/core/polyhedral/scop.h b/tc/core/polyhedral/scop.h
index e163dbe90..6fd048427 100644
--- a/tc/core/polyhedral/scop.h
+++ b/tc/core/polyhedral/scop.h
@@ -215,6 +215,14 @@ struct Scop {
     insertExtensionLabelAfter(scheduleRoot(), tree, makeSyncId(level));
   }
 
+  void insertTimeoutCheckAfter(detail::ScheduleTree* tree) {
+    insertExtensionLabelAfter(scheduleRoot(), tree, makeTimeoutCheckId());
+  }
+
+  void insertTimeoutCheck(detail::ScheduleTree* seqNode, size_t pos) {
+    insertExtensionLabelAt(scheduleRoot(), seqNode, pos, makeTimeoutCheckId());
+  }
+
   size_t reductionUID() const {
     static
 size_t count = 0;
     return count++;
   }
@@ -227,6 +235,10 @@ struct Scop {
     static size_t count = 0;
     return count++;
   }
+  size_t timeoutCheckUID() const {
+    static size_t count = 0;
+    return count++;
+  }
 
   // Make the synchronization id corresponding to the synchronization level.
   // The level should not be None.
@@ -255,6 +267,13 @@ struct Scop {
         ctx, std::string(kWarpSyncIdPrefix) + std::to_string(warpSyncUID()));
   }
 
+  isl::id makeTimeoutCheckId() const {
+    auto ctx = domain().get_ctx();
+    return isl::id(
+        ctx,
+        std::string(kTimeoutCheckPrefix) + std::to_string(timeoutCheckUID()));
+  }
+
   // Check if the id has a name with the expected prefix, followed by a long
   // integer.
   static bool isIdWithExpectedPrefix(
@@ -281,6 +300,10 @@ struct Scop {
     return isIdWithExpectedPrefix(id, kWarpSyncIdPrefix);
   }
 
+  static bool isTimeoutCheckId(isl::id id) {
+    return isIdWithExpectedPrefix(id, kTimeoutCheckPrefix);
+  }
+
   static isl::id makeRefId(isl::ctx ctx) {
     static thread_local size_t count = 0;
     return isl::id(ctx, std::string("__tc_ref_") + std::to_string(count++));
diff --git a/tc/core/polyhedral/unroll.cc b/tc/core/polyhedral/unroll.cc
index 7047bd0d1..9c483d43c 100644
--- a/tc/core/polyhedral/unroll.cc
+++ b/tc/core/polyhedral/unroll.cc
@@ -23,47 +23,6 @@ namespace tc {
 namespace polyhedral {
 namespace {
 
-/*
- * Return a bound on the range of values attained by "f" for fixed
- * values of "fixed", taking into account basic strides
- * in the range of values attained by "f".
- *
- * First construct a map from values of "fixed" to corresponding
- * values of "f". If this map is empty, then "f" cannot attain
- * any values and the bound is zero.
- * Otherwise, consider pairs of "f" values for the same value
- * of "fixed" and take their difference over all possible values
- * of the parameters and of the "fixed" values.
- * Take a simple overapproximation as a convex set and
- * determine the stride in the value differences.
- * The possibly quasi-affine set is then overapproximated by an affine set.
- * At this point, the set is a possibly infinite, symmetrical interval.
- * Take the maximal value of the difference divided by the stride plus one as
- * a bound on the number of possible values of "f".
- * That is, take M/s + 1. Note that 0 is always an element of
- * the difference set, so no offset needs to be taken into account
- * during the stride computation and M is an integer multiple of s.
- */
-isl::val relativeRange(isl::union_map fixed, isl::union_pw_aff f) {
-  auto ctx = f.get_ctx();
-  auto umap = isl::union_map::from(isl::multi_union_pw_aff(f));
-  umap = umap.apply_domain(fixed);
-  if (umap.is_empty()) {
-    return isl::val::zero(ctx);
-  }
-
-  umap = umap.range_product(umap);
-  umap = umap.range().unwrap();
-  umap = umap.project_out_all_params();
-  auto delta = isl::map::from_union_map(umap).deltas();
-  auto hull = delta.simple_hull();
-  auto stride = isl::set(hull).get_stride(0);
-  hull = isl::set(hull).polyhedral_hull();
-  auto bound = hull.dim_max_val(0);
-  bound = bound.div(stride);
-  bound = bound.add(isl::val::one(ctx));
-  return bound;
-}
 
 /*
  * Compute a bound on the number of instances executed by "band" and
diff --git a/tc/external/detail/islpp-inl.h b/tc/external/detail/islpp-inl.h
index 38edad204..dcbd1c73d 100644
--- a/tc/external/detail/islpp-inl.h
+++ b/tc/external/detail/islpp-inl.h
@@ -215,4 +215,28 @@ inline isl::set operator&(isl::point P1, isl::set S2) {
   return S2 & P1;
 }
 
+///////////////////////////////////////////////////////////////////////////////
+// Helper functions
+///////////////////////////////////////////////////////////////////////////////
+
+inline isl::val relativeRange(isl::union_map fixed, isl::union_pw_aff f) {
+  auto ctx = f.get_ctx();
+  auto umap = isl::union_map::from(isl::multi_union_pw_aff(f));
+  umap = umap.apply_domain(fixed);
+  if (umap.is_empty()) {
+    return isl::val::zero(ctx);
+  }
+
+  umap = umap.range_product(umap);
+  umap = umap.range().unwrap();
+  umap = umap.project_out_all_params();
+  auto delta = isl::map::from_union_map(umap).deltas();
+  auto hull = delta.simple_hull();
+  auto stride = isl::set(hull).get_stride(0);
+  hull = isl::set(hull).polyhedral_hull();
+  auto bound = hull.dim_max_val(0);
+  bound = bound.div(stride);
+  bound = bound.add(isl::val::one(ctx));
+  return bound;
+}
+
 } // namespace isl
diff --git a/tc/external/detail/islpp.h b/tc/external/detail/islpp.h
index affbd51ae..5546d8acc 100644
--- a/tc/external/detail/islpp.h
+++ b/tc/external/detail/islpp.h
@@ -347,6 +347,29 @@ inline isl::set makeSpecializationSet(
   return makeSpecializationSet(space, map);
 }
 
+/*
+ * Return a bound on the range of values attained by "f" for fixed
+ * values of "fixed", taking into account basic strides
+ * in the range of values attained by "f".
+ *
+ * First construct a map from values of "fixed" to corresponding
+ * values of "f". If this map is empty, then "f" cannot attain
+ * any values and the bound is zero.
+ * Otherwise, consider pairs of "f" values for the same value
+ * of "fixed" and take their difference over all possible values
+ * of the parameters and of the "fixed" values.
+ * Take a simple overapproximation as a convex set and
+ * determine the stride in the value differences.
+ * The possibly quasi-affine set is then overapproximated by an affine set.
+ * At this point, the set is a possibly infinite, symmetrical interval.
+ * Take the maximal value of the difference divided by the stride plus one as
+ * a bound on the number of possible values of "f".
+ * That is, take M/s + 1. Note that 0 is always an element of
+ * the difference set, so no offset needs to be taken into account
+ * during the stride computation and M is an integer multiple of s.
+ */
+inline isl::val relativeRange(isl::union_map fixed, isl::union_pw_aff f);
+
 namespace detail {
 
 // Helper class used to support range-based for loops on isl::*_list types.
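For intuition about the M/s + 1 bound (an illustrative example, not taken from the codebase): if, for a fixed prefix, "f" attains the values {0, 4, 8}, the pairwise difference set is {-8, -4, 0, 4, 8}. Its stride is s = 4 and its maximum is M = 8, so the bound is 8/4 + 1 = 3, exactly the number of distinct values. The hoisting of this helper from unroll.cc into the isl wrappers is what lets the new timeout-check placement in mapped_scop.cc reuse it to estimate per-band trip counts.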
diff --git a/tc/proto/mapping_options.proto b/tc/proto/mapping_options.proto
index ff29e3557..72e36d708 100644
--- a/tc/proto/mapping_options.proto
+++ b/tc/proto/mapping_options.proto
@@ -70,6 +70,8 @@ message CudaMappingOptionsProto {
   optional uint64 max_shared_memory = 7;
   // Use the readonly cache (i.e. emit __ldg loads)
   required bool use_readonly_cache = 8;
+  // If provided, generate timeout checks. The given timeout is in ms.
+  optional uint32 timeout = 9;
 }
 
 message CpuMappingOptionsProto {
diff --git a/test/test_cuda_mapper.cc b/test/test_cuda_mapper.cc
index f33371cfb..38c40f4ed 100644
--- a/test/test_cuda_mapper.cc
+++ b/test/test_cuda_mapper.cc
@@ -91,6 +91,7 @@ struct PolyhedralMapperTest : public ::testing::Test {
         Grid{1},
         Block{blockSizes[0], blockSizes[1]},
         0,
+        false,
         false);
     auto band = mscop->mapBlocksForward(root->child({0}), 1);
     bandScale(band, tileSizes);
@@ -114,6 +115,7 @@ struct PolyhedralMapperTest : public ::testing::Test {
         Grid{gridSizes[0], gridSizes[1]},
         Block{blockSizes[0], blockSizes[1]},
         0,
+        false,
         false);
 
     // Map to blocks
@@ -364,13 +366,15 @@ def fun(float(N, M) A, float(N, M) B) -> (C) {
 
   auto res = mscop->codegen(specializedName);
 
-  std::string expected(
+  std::string expectedIds(
       R"RES(int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
-  int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
-  float32 (*C)[M] = reinterpret_cast<float32 (*)[M]>(pC);
+  int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;)RES");
+  std::string expectedDeclarations(
+      R"RES(float32 (*C)[M] = reinterpret_cast<float32 (*)[M]>(pC);
   const float32 (*A)[M] = reinterpret_cast<const float32 (*)[M]>(pA);
-  const float32 (*B)[M] = reinterpret_cast<const float32 (*)[M]>(pB);
-  for (int c1 = 16 * b1; c1 < M; c1 += 4096) {
+  const float32 (*B)[M] = reinterpret_cast<const float32 (*)[M]>(pB);)RES");
+  std::string expectedCompute(
+      R"RES(for (int c1 = 16 * b1; c1 < M; c1 += 4096) {
     if (M >= t0 + c1 + 1) {
       C[(t1 + 16 * b0)][(t0 + c1)] = (A[(t1 + 16 * b0)][(t0 + c1)] + B[(t1 + 16 * b0)][(t0 + c1)]);
     }
@@ -378,7 +382,11 @@ def fun(float(N, M) A, float(N, M) B) -> (C) {
   }
 )RES");
 
-  ASSERT_NE(std::string::npos, std::get<0>(res).find(expected))
+  ASSERT_NE(std::string::npos, std::get<0>(res).find(expectedIds))
+      << std::get<0>(res);
+  ASSERT_NE(std::string::npos, std::get<0>(res).find(expectedDeclarations))
+      << std::get<0>(res);
+  ASSERT_NE(std::string::npos, std::get<0>(res).find(expectedCompute))
       << std::get<0>(res);
   ASSERT_EQ(32u, std::get<1>(res).view[0])
       << "Improper dim in: " << std::get<1>(res).view;
@@ -399,17 +407,21 @@ def fun(float(N, N, N, N) A, float(N, N) B, float(N, N) C, float(N, N) D)
   // Don't intersect context with the domain and see what happens
   auto res = std::get<0>(mscop->codegen(specializedName));
 
-  std::string expected(
+  std::string expectedIds(
       R"RES(int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
-  int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
-  float32 (*O1)[N] = reinterpret_cast<float32 (*)[N]>(pO1);
+  int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;)RES");
+
+  std::string expectedDeclarations(
+      R"RES(float32 (*O1)[N] = reinterpret_cast<float32 (*)[N]>(pO1);
   float32 (*O2)[N] = reinterpret_cast<float32 (*)[N]>(pO2);
   float32 (*O3)[N] = reinterpret_cast<float32 (*)[N]>(pO3);
   const float32 (*A)[N][N][N] = reinterpret_cast<const float32 (*)[N][N][N]>(pA);
   const float32 (*B)[N] = reinterpret_cast<const float32 (*)[N]>(pB);
   const float32 (*C)[N] = reinterpret_cast<const float32 (*)[N]>(pC);
-  const float32 (*D)[N] = reinterpret_cast<const float32 (*)[N]>(pD);
-  for (int c0 = 0; c0 < N; c0 += 1) {
+  const float32 (*D)[N] = reinterpret_cast<const float32 (*)[N]>(pD);)RES");
+
+  std::string expectedCompute(
+      R"RES(for (int c0 = 0; c0 < N; c0 += 1) {
     for (int c1 = 0; c1 < N; c1 += 1) {
       O1[c0][c1] = 0.000000f;
     }
@@ -436,7 +448,9 @@ def fun(float(N, N, N, N) A, float(N, N) B, float(N, N) C, float(N, N) D)
   }
 )RES");
 
-  ASSERT_NE(std::string::npos, res.find(expected)) << res;
+  ASSERT_NE(std::string::npos, res.find(expectedIds)) << res;
+  ASSERT_NE(std::string::npos, res.find(expectedDeclarations)) << res;
+  ASSERT_NE(std::string::npos, res.find(expectedCompute)) << res;
 }
 
 TEST_F(PolyhedralMapperTest, BareVariables) {
@@ -450,13 +464,11 @@ def fun(float(N, N) A) -> (O)
   auto mscop = makeUnmapped(tc);
   auto res = std::get<0>(mscop->codegen(specializedName));
 
-  string expected(
-      R"RES(__global__ void kernel_anon(int32 N, float32* pO, const float32* pA) {
-  int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
-  int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
-  float32 (*O)[N] = reinterpret_cast<float32 (*)[N]>(pO);
-  const float32 (*A)[N] = reinterpret_cast<const float32 (*)[N]>(pA);
-  for (int c0 = 0; c0 < N; c0 += 1) {
+  string expected1(
+      R"RES(float32 (*O)[N] = reinterpret_cast<float32 (*)[N]>(pO);
+  const float32 (*A)[N] = reinterpret_cast<const float32 (*)[N]>(pA);)RES");
+  string expected2(
+      R"RES(for (int c0 = 0; c0 < N; c0 += 1) {
     for (int c1 = 0; c1 < N; c1 += 1) {
       O[c0][c1] = (((A[c0][c1] + float32(c0)) + float32(c1)) + float32(N));
     }
@@ -464,7 +476,8 @@ def fun(float(N, N) A) -> (O)
   }
 )RES");
 
-  ASSERT_NE(std::string::npos, res.find(expected)) << res;
+  ASSERT_NE(std::string::npos, res.find(expected1)) << res;
+  ASSERT_NE(std::string::npos, res.find(expected2)) << res;
 }
 
 TEST_F(PolyhedralMapperTest, CudaFunctions) {
@@ -479,15 +492,18 @@ def fun(float(N, N) A, float(N, N) B, float(N) C) -> (O)
   mscop->fixParameters({{"N", 512}});
   auto res = std::get<0>(mscop->codegen(specializedName));
 
-  string expected =
-      R"RES(__global__ void kernel_anon(int32 N, float32* pO, const float32* pA, const float32* pB, const float32* pC) {
-  int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
-  int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
-  float32 (*O)[512] = reinterpret_cast<float32 (*)[512]>(pO);
+  string expectedIds =
+      R"RES(int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
+  int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;)RES";
+
+  string expectedDeclarations =
+      R"RES(float32 (*O)[512] = reinterpret_cast<float32 (*)[512]>(pO);
   const float32 (*A)[512] = reinterpret_cast<const float32 (*)[512]>(pA);
   const float32 (*B)[512] = reinterpret_cast<const float32 (*)[512]>(pB);
-  const float32 (*C) = reinterpret_cast<const float32*>(pC);
-  for (int c0 = 0; c0 <= 511; c0 += 1) {
+  const float32 (*C) = reinterpret_cast<const float32*>(pC);)RES";
+
+  string expectedCompute =
+      R"RES(for (int c0 = 0; c0 <= 511; c0 += 1) {
     for (int c1 = 0; c1 <= 511; c1 += 1) {
       O[c0][c1] = (nextafter(C[c0], exp(A[c0][c1])) + log(B[c1][c0]));
     }
@@ -495,16 +511,32 @@ def fun(float(N, N) A, float(N, N) B, float(N) C) -> (O)
   }
 )RES";
 
-  ASSERT_NE(std::string::npos, res.find(expected)) << res;
+  ASSERT_NE(std::string::npos, res.find(expectedIds)) << res;
+  ASSERT_NE(std::string::npos, res.find(expectedDeclarations)) << res;
+  ASSERT_NE(std::string::npos, res.find(expectedCompute)) << res;
 }
 
-constexpr auto kExpectedMatmul_64_64_64 =
-    R"CUDA(int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
-  int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
-  float32 (*O)[64] = reinterpret_cast<float32 (*)[64]>(pO);
+TEST_F(PolyhedralMapperTest, MergedContexts) {
+  auto scop = PrepareAndJoinBandsMatMul();
+
+  // Unit test claims to use the specialized context properly
+  scop->fixParameters({{"M", 64}, {"N", 64}, {"K", 64}});
+  scop->specializeToContext();
+
+  auto mscop = TileAndMapThreads(std::move(scop), {16, 16}, {32ul, 8ul});
+  auto res = std::get<0>(mscop->codegen(specializedName));
+
+  string expectedIds =
+      R"CUDA(int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
+  int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;)CUDA";
+
+  string expectedDeclarations =
+      R"CUDA(float32 (*O)[64] = reinterpret_cast<float32 (*)[64]>(pO);
   const float32 (*A)[64] = reinterpret_cast<const float32 (*)[64]>(pA);
-  const float32 (*B)[64] = reinterpret_cast<const float32 (*)[64]>(pB);
-  for (int c0 = 0; c0 <= 63; c0 += 16) {
+  const float32 (*B)[64] = reinterpret_cast<const float32 (*)[64]>(pB);)CUDA";
+
+  string expectedCompute =
+      R"CUDA(for (int c0 = 0; c0 <= 63; c0 += 16) {
     for (int c1 = 0; c1 <= 63; c1 += 16) {
       for (int c2 = t1; c2 <= 15; c2 += 8) {
         O[(c0 + c2)][(t0 + c1)] = 0.000000f;
@@ -517,16 +549,9 @@ def fun(float(M, K) A, float(K, N) B) -> (O)
       }
     }
 )CUDA";
 
-TEST_F(PolyhedralMapperTest, MergedContexts) {
-  auto scop = PrepareAndJoinBandsMatMul();
-
-  // Unit test claims to use the specialized context properly
-  scop->fixParameters({{"M", 64}, {"N", 64}, {"K", 64}});
-  scop->specializeToContext();
-
-  auto mscop = TileAndMapThreads(std::move(scop), {16, 16}, {32ul, 8ul});
-  auto res = std::get<0>(mscop->codegen(specializedName));
-  ASSERT_TRUE(std::string::npos != res.find(kExpectedMatmul_64_64_64)) << res;
+  ASSERT_TRUE(std::string::npos != res.find(expectedIds)) << res;
+  ASSERT_TRUE(std::string::npos != res.find(expectedDeclarations)) << res;
+  ASSERT_TRUE(std::string::npos != res.find(expectedCompute)) << res;
 }
 
 TEST_F(PolyhedralMapperTest, Match1) {