Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 6ceae90

Browse files
authored
Merge pull request #469 from facebookresearch/pr/context
emitCudaKernel: store parameter values in specialized scop and take them from there
2 parents c21c93f + 1b390a9 commit 6ceae90

16 files changed

+101
-171
lines changed

tc/core/cuda/cuda_tc_executor.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,8 @@ CudaCompilationResult CudaBackend::compileWithTcMapper(
7777
// context to specialize the scop..
7878
auto scop = polyhedral::Scop::makeScop(
7979
isl::with_exceptions::globalIslCtx(), halideComponents);
80-
auto globalParameterContext = scop->makeContextFromInputs(inputs);
81-
scop = polyhedral::Scop::makeSpecializedScop(
82-
*scop, globalParameterContext.intersect(scop->globalParameterContext));
80+
auto pvm = computeParamValueMap(halideComponents, inputs);
81+
scop = polyhedral::Scop::makeSpecializedScop(*scop, pvm);
8382
LOG_IF(INFO, FLAGS_debug_tc_mapper) << options;
8483
LOG_IF(INFO, FLAGS_debug_tc_mapper) << "original schedule:\n"
8584
<< *(scop->scheduleRoot());
@@ -91,8 +90,7 @@ CudaCompilationResult CudaBackend::compileWithTcMapper(
9190
LOG_IF(INFO, FLAGS_debug_tc_mapper) << "Mapped schedule:" << std::endl
9291
<< *(mappedScop->schedule());
9392

94-
auto parameters =
95-
mappedScop->scop().getParameterValues(globalParameterContext);
93+
auto parameters = mappedScop->scop().getParameterValues();
9694
auto specializedName = specializeKernelName(tcName, parameters);
9795

9896
// This updates the launch bounds with the actual result from compilation

tc/core/halide2isl.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -227,19 +227,19 @@ isl::aff makeIslAffFromExpr(isl::space space, const Expr& e) {
227227
return list[0];
228228
}
229229

230-
isl::space makeParamSpace(isl::ctx ctx, const SymbolTable& symbolTable) {
230+
isl::space makeParamSpace(isl::ctx ctx, const ParameterVector& params) {
231231
auto space = isl::space(ctx, 0);
232232
// set parameter names
233-
for (auto p : symbolTable.params) {
233+
for (auto p : params) {
234234
space = space.add_param(isl::id(ctx, p.name()));
235235
}
236236
return space;
237237
}
238238

239-
isl::set makeParamContext(isl::ctx ctx, const SymbolTable& symbolTable) {
240-
auto space = makeParamSpace(ctx, symbolTable);
239+
isl::set makeParamContext(isl::ctx ctx, const ParameterVector& params) {
240+
auto space = makeParamSpace(ctx, params);
241241
auto context = isl::set::universe(space);
242-
for (auto p : symbolTable.params) {
242+
for (auto p : params) {
243243
isl::aff a(isl::aff::param_on_domain_space(space, isl::id(ctx, p.name())));
244244
context = context & (a >= 0);
245245
}

tc/core/halide2isl.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,20 @@ namespace halide2isl {
3333
/// Helper functions that participate in translating Halide IR to ISL
3434
///
3535

36+
using ParameterVector = std::vector<Halide::Internal::Parameter>;
3637
/// Find and categorize all variables referenced in a piece of Halide IR
3738
struct SymbolTable {
3839
std::vector<std::string> reductionVars, idxVars;
39-
std::vector<Halide::Internal::Parameter> params;
40+
ParameterVector params;
4041
};
4142
SymbolTable makeSymbolTable(const tc2halide::HalideComponents& components);
4243

43-
/// Make the space of all parameter values from the symbol table
44-
isl::space makeParamSpace(isl::ctx ctx, const SymbolTable& symbolTable);
44+
/// Make the space of all given parameter values
45+
isl::space makeParamSpace(isl::ctx ctx, const ParameterVector& params);
4546

46-
/// Make the parameter set marking all parameters from the symbol table
47+
/// Make the parameter set marking all given parameters
4748
/// as non-negative.
48-
isl::set makeParamContext(isl::ctx ctx, const SymbolTable& symbolTable);
49+
isl::set makeParamContext(isl::ctx ctx, const ParameterVector& params);
4950

5051
/// Make a constant-valued affine function over a space.
5152
isl::aff makeIslAffFromInt(isl::space space, int64_t i);

tc/core/halide_utils.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "tc/core/halide_utils.h"
1717

1818
#include <map>
19+
#include <unordered_map>
1920
#include <vector>
2021

2122
#include "tc/core/flags.h"
@@ -36,10 +37,10 @@ DLDataType fromHalideType(const Halide::Type& type) {
3637
return dtype;
3738
}
3839

39-
std::map<std::string, int> computeParamValueMap(
40+
std::unordered_map<std::string, int> computeParamValueMap(
4041
const tc2halide::HalideComponents& halide,
4142
const std::vector<const DLConstTensor*>& inputsDLT) {
42-
std::map<std::string, int> pvm;
43+
std::unordered_map<std::string, int> pvm;
4344
if (halide.inputs.size() != inputsDLT.size()) {
4445
throw lang::ErrorReport(halide.getDef())
4546
<< "expected " << halide.inputs.size() << " inputs but got "

tc/core/halide_utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
#include <chrono>
1919
#include <string>
20-
#include <unordered_set>
20+
#include <unordered_map>
2121
#include <vector>
2222

2323
#include "tc/core/tc2halide.h"
@@ -29,7 +29,7 @@ namespace tc {
2929
/// (metadata of) input tensors with specific shapes, compute a map between TC
3030
/// parametric tensor sizes, represented as strings, and their numerical values
3131
/// with given input sizes.
32-
std::map<std::string, int> computeParamValueMap(
32+
std::unordered_map<std::string, int> computeParamValueMap(
3333
const tc2halide::HalideComponents& components,
3434
const std::vector<const DLConstTensor*>& inputsDLT);
3535

tc/core/polyhedral/codegen_llvm.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,8 +324,7 @@ llvm::Value* CodeGen_TC::getValue(isl::ast_expr expr) {
324324

325325
class LLVMCodegen {
326326
void collectTensor(const Halide::OutputImageParam& t) {
327-
auto sizes =
328-
getTensorSizesWithoutLeadingDim(t, scop_.globalParameterContext);
327+
auto sizes = getTensorSizesWithoutLeadingDim(t, scop_.context());
329328
if (not sizes.empty()) {
330329
args_.emplace_back(
331330
makePtrToArrayType(halide_cg.llvm_type_of(t.type()), sizes));
@@ -509,7 +508,7 @@ class LLVMCodegen {
509508
CHECK(condLHS);
510509
CHECK_EQ(condLHS.get_id(), iterator);
511510

512-
IslAstExprInterpeter i(scop_.globalParameterContext);
511+
IslAstExprInterpeter i(scop_.context());
513512
auto condRHSVal = i.interpret(cond_expr.get_arg(1));
514513

515514
auto cond = [&]() {

tc/core/polyhedral/cuda/codegen.cc

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -747,15 +747,8 @@ string emitCudaKernel(
747747

748748
// Make a map of the specialized scalar parameter values
749749
map<string, Halide::Expr> paramValues;
750-
{
751-
auto set = scop.globalParameterContext;
752-
for (unsigned i = 0; i < set.n_param(); i++) {
753-
auto val = set.plain_get_val_if_fixed(isl::dim_type::param, i);
754-
auto name = set.get_space().get_dim_name(isl::dim_type::param, i);
755-
if (!val.is_nan()) {
756-
paramValues[name] = static_cast<int>(val.get_num_si());
757-
}
758-
}
750+
for (const auto& kvp : scop.parameterValues) {
751+
paramValues[kvp.first] = kvp.second;
759752
}
760753

761754
stringstream ss;

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,7 @@ std::unique_ptr<MappedScop> makeSpecializedMappedScop(
884884
// outer schedule dimensions, so the space of a parameter context code is that
885885
// of a zero-dimensional space.
886886
auto root = scop->scheduleRoot();
887-
updateTopLevelContext(root, scop->globalParameterContext.from_params());
887+
updateTopLevelContext(root, scop->context().from_params());
888888

889889
tc::Grid grid = mappedScop.numBlocks;
890890
tc::Block block = mappedScop.numThreads;
@@ -907,7 +907,7 @@ std::unique_ptr<MappedScop> makeSpecializedMappedScop(
907907
} // namespace
908908

909909
// Before generating code, make a copy of the scop and insert
910-
// the globalParameterContext of the original scop as top-level
910+
// the context of the original scop as top-level
911911
// context node in schedule tree.
912912
std::tuple<std::string, tc::Grid, tc::Block> MappedScop::codegen(
913913
const std::string& specializedName) const {

tc/core/polyhedral/cuda/tighten_launch_bounds.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ std::pair<tc::Grid, tc::Block> tightenLaunchBounds(
8484
const tc::Grid& grid,
8585
const tc::Block& block) {
8686
auto root = scop.scheduleRoot();
87-
auto params = scop.globalParameterContext;
87+
auto params = scop.context();
8888

8989
auto max = [root, params](const mapping::MappingId& id) -> size_t {
9090
size_t sizetMax = std::numeric_limits<size_t>::max();

tc/core/polyhedral/memory_promotion.cc

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -421,10 +421,7 @@ isl::set tensorElementsSet(const Scop& scop, isl::id tensorId) {
421421
(isl::aff_set(aff) < (minAff + extentAff));
422422
}
423423

424-
if (scop.globalParameterContext) {
425-
tensorElements =
426-
tensorElements.intersect_params(scop.globalParameterContext);
427-
}
424+
tensorElements = tensorElements.intersect_params(scop.context());
428425
return tensorElements;
429426
}
430427

0 commit comments

Comments
 (0)