Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 9cd8fc9

Browse files
author
Protonu Basu
committed
Add support for strided tensors
This commit is to start support for strided tensors. I made changes to percolate a vector in TensorInfo down to emitCudaKernel to allow codegen to cast strided tensors. This required changes to an unit test to expect the correct cast.
1 parent cc4b1eb commit 9cd8fc9

File tree

6 files changed

+34
-15
lines changed

6 files changed

+34
-15
lines changed

tc/core/cuda/cuda_tc_executor.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,16 @@ CudaCompilationResult CudaBackend::compileWithTcMapper(
9393
auto parameters = mappedScop->scop().getParameterValues();
9494
auto specializedName = specializeKernelName(tcName, parameters);
9595

96+
auto inputsInfo = makeTensorInfoVector(inputs);
97+
9698
// This updates the launch bounds with the actual result from compilation
9799
// with tightening of launch_bounds. What you get is not necessarily what
98100
// you asked for, the autotuner should adapt to that.
99101
std::string source;
100102
Grid grid;
101103
Block block;
102-
std::tie(source, grid, block) = mappedScop->codegen(specializedName);
104+
std::tie(source, grid, block) =
105+
mappedScop->codegen(specializedName, inputsInfo);
103106
LOG_IF(INFO, FLAGS_dump_cuda) << "generatedCuda: " << source << "\n"
104107
<< "grid: " << grid << " block: " << block;
105108

tc/core/polyhedral/cuda/codegen.cc

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -183,15 +183,23 @@ void emitTensorView(
183183
stringstream& ss,
184184
Halide::OutputImageParam p,
185185
const map<string, Halide::Expr>& paramValues,
186-
bool constInput = false) {
186+
bool constInput = false,
187+
const TensorInfo* tinfo = NULL) {
187188
WS ws;
188189
stringstream ssViewType;
189190
for (int i = 1; i < p.dimensions(); ++i) { // Skip the outermost dimension
190191
Halide::Expr extent = p.parameter().extent_constraint(i);
191192
extent = Halide::Internal::substitute(paramValues, extent);
192193
CHECK(extent.defined())
193194
<< "Undefined extent on input/output tensor. Forward bounds inference should have set these\n";
194-
ssViewType << "[" << extent << "]";
195+
// TODO: Handle non-unit stride in the innermost dimension
196+
if (tinfo && tinfo->strides.size() == p.dimensions() &&
197+
tinfo->strides[p.dimensions() - 1] == 1 &&
198+
tinfo->strides[i - 1] != (tinfo->shape[i] * tinfo->strides[i])) {
199+
ssViewType << "[" << tinfo->strides[i - 1] << "]";
200+
} else {
201+
ssViewType << "[" << extent << "]";
202+
}
195203
}
196204
ss << ws.tab();
197205
ss << (constInput ? "const " : "") << p.type() << " (*" << p.name() << ")"
@@ -216,9 +224,12 @@ void emitTensorViews(
216224
void emitTensorViews(
217225
stringstream& ss,
218226
const vector<Halide::ImageParam>& params,
219-
const map<string, Halide::Expr>& paramValues) {
220-
for (auto p : params) {
221-
emitTensorView(ss, p, paramValues, true);
227+
const map<string, Halide::Expr>& paramValues,
228+
const std::vector<TensorInfo>& inputsInfo) {
229+
for (size_t i = 0; i < params.size(); ++i) {
230+
inputsInfo.size()
231+
? emitTensorView(ss, params[i], paramValues, true, &inputsInfo[i])
232+
: emitTensorView(ss, params[i], paramValues, true);
222233
}
223234
}
224235

@@ -738,7 +749,8 @@ std::unordered_set<isl::id, isl::IslIdIslHash> gatherReadOnlySet(
738749

739750
string emitCudaKernel(
740751
const std::string& specializedName,
741-
const MappedScop& mscop) {
752+
const MappedScop& mscop,
753+
const std::vector<TensorInfo>& inputsInfo) {
742754
// Expecting a schedule with domain root and context first child.
743755
CHECK(mscop.schedule()->elemAs<detail::ScheduleTreeElemDomain>());
744756
CHECK(
@@ -755,7 +767,7 @@ string emitCudaKernel(
755767
emitKernelSignature(ss, specializedName, scop);
756768
emitThreadIdInit(ss, mscop);
757769
emitTensorViews(ss, scop.halide.outputs, paramValues);
758-
emitTensorViews(ss, scop.halide.inputs, paramValues);
770+
emitTensorViews(ss, scop.halide.inputs, paramValues, inputsInfo);
759771
emitTmpDecl(ss, scop);
760772
emitPromotedArrayViewsHalide(ss, scop);
761773
NodeInfoMapType nodeInfoMap;

tc/core/polyhedral/cuda/codegen.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@ struct CodegenStatementContext : CodegenContext {
145145

146146
std::string emitCudaKernel(
147147
const std::string& specializedName,
148-
const MappedScop& scop);
148+
const MappedScop& scop,
149+
const std::vector<TensorInfo>& inputsInfo);
149150

150151
} // namespace polyhedral
151152
} // namespace tc

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -910,7 +910,8 @@ std::unique_ptr<MappedScop> makeSpecializedMappedScop(
910910
// the context of the original scop as top-level
911911
// context node in schedule tree.
912912
std::tuple<std::string, tc::Grid, tc::Block> MappedScop::codegen(
913-
const std::string& specializedName) const {
913+
const std::string& specializedName,
914+
const std::vector<TensorInfo>& inputsInfo) const {
914915
validate(schedule());
915916

916917
auto mappedScopForCodegen = makeSpecializedMappedScop(*this);
@@ -927,8 +928,8 @@ std::tuple<std::string, tc::Grid, tc::Block> MappedScop::codegen(
927928
code << code::cuda::cubBlockReduce;
928929
}
929930
code << "extern \"C\" {" << std::endl
930-
<< emitCudaKernel(specializedName, *mappedScopForCodegen) << "}"
931-
<< std::endl;
931+
<< emitCudaKernel(specializedName, *mappedScopForCodegen, inputsInfo)
932+
<< "}" << std::endl;
932933

933934
return std::make_tuple(
934935
code.str(),

tc/core/polyhedral/cuda/mapped_scop.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,9 @@ class MappedScop {
115115
// Generate CUDA code at the current state of transformation provided a
116116
// name for the generated function.
117117
std::tuple<std::string, tc::Grid, tc::Block> codegen(
118-
const std::string& specializedName) const;
118+
const std::string& specializedName,
119+
const std::vector<TensorInfo>& inputsInfo =
120+
std::vector<TensorInfo>{}) const;
119121

120122
// Accessors..
121123
// Const accessor to schedule of underlying Scop.

test/cuda/test_tc_mapper.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,8 +326,8 @@ def tensoraddstrided(float(N, M) I0_view, float(N, M) I1_view) -> (O) {
326326
auto res = Check(TC, name, options, inputs, checkFun);
327327
// This test should be modified when strided tensors are handled
328328
std::string expected =
329-
"const float32 (*I0_view)[64] = "
330-
"reinterpret_cast<const float32 (*)[64]>(pI0_view)";
329+
"const float32 (*I0_view)[128] = "
330+
"reinterpret_cast<const float32 (*)[128]>(pI0_view)";
331331
ASSERT_NE(std::string::npos, res.second.find(expected))
332332
<< "In resulting code:\n"
333333
<< res.second << "\nfound unexpected: " << expected;

0 commit comments

Comments
 (0)