Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit ff1ed36

Browse files
author
Protonu Basu
committed
Add support for strided tensors
This commit starts support for strided tensors. I made changes to percolate a vector of TensorInfo down to emitCudaKernel to allow codegen to cast strided tensors. This required changes to a unit test to expect the correct cast.
1 parent cc4b1eb commit ff1ed36

File tree

8 files changed

+260
-15
lines changed

8 files changed

+260
-15
lines changed

.test_tc_mapper_output.txt.swp

16 KB
Binary file not shown.

tc/core/cuda/cuda_tc_executor.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,16 @@ CudaCompilationResult CudaBackend::compileWithTcMapper(
9393
auto parameters = mappedScop->scop().getParameterValues();
9494
auto specializedName = specializeKernelName(tcName, parameters);
9595

96+
auto inputsInfo = makeTensorInfoVector(inputs);
97+
9698
// This updates the launch bounds with the actual result from compilation
9799
// with tightening of launch_bounds. What you get is not necessarily what
98100
// you asked for, the autotuner should adapt to that.
99101
std::string source;
100102
Grid grid;
101103
Block block;
102-
std::tie(source, grid, block) = mappedScop->codegen(specializedName);
104+
std::tie(source, grid, block) =
105+
mappedScop->codegen(specializedName, inputsInfo);
103106
LOG_IF(INFO, FLAGS_dump_cuda) << "generatedCuda: " << source << "\n"
104107
<< "grid: " << grid << " block: " << block;
105108

tc/core/polyhedral/cuda/codegen.cc

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -183,15 +183,23 @@ void emitTensorView(
183183
stringstream& ss,
184184
Halide::OutputImageParam p,
185185
const map<string, Halide::Expr>& paramValues,
186-
bool constInput = false) {
186+
bool constInput = false,
187+
const TensorInfo* tinfo = NULL) {
187188
WS ws;
188189
stringstream ssViewType;
189190
for (int i = 1; i < p.dimensions(); ++i) { // Skip the outermost dimension
190191
Halide::Expr extent = p.parameter().extent_constraint(i);
191192
extent = Halide::Internal::substitute(paramValues, extent);
192193
CHECK(extent.defined())
193194
<< "Undefined extent on input/output tensor. Forward bounds inference should have set these\n";
194-
ssViewType << "[" << extent << "]";
195+
// TODO: Handle non-unit stride in the innermost dimension
196+
if (tinfo && tinfo->strides.size() == p.dimensions() &&
197+
tinfo->strides[p.dimensions() - 1] == 1 &&
198+
tinfo->strides[i - 1] != (tinfo->shape[i] * tinfo->strides[i])) {
199+
ssViewType << "[" << tinfo->strides[i - 1] << "]";
200+
} else {
201+
ssViewType << "[" << extent << "]";
202+
}
195203
}
196204
ss << ws.tab();
197205
ss << (constInput ? "const " : "") << p.type() << " (*" << p.name() << ")"
@@ -216,9 +224,12 @@ void emitTensorViews(
216224
void emitTensorViews(
217225
stringstream& ss,
218226
const vector<Halide::ImageParam>& params,
219-
const map<string, Halide::Expr>& paramValues) {
220-
for (auto p : params) {
221-
emitTensorView(ss, p, paramValues, true);
227+
const map<string, Halide::Expr>& paramValues,
228+
const std::vector<TensorInfo>& inputsInfo = std::vector<TensorInfo>{}) {
229+
for (size_t i = 0; i < params.size(); ++i) {
230+
inputsInfo.size()
231+
? emitTensorView(ss, params[i], paramValues, true, &inputsInfo[i])
232+
: emitTensorView(ss, params[i], paramValues, true);
222233
}
223234
}
224235

@@ -738,7 +749,8 @@ std::unordered_set<isl::id, isl::IslIdIslHash> gatherReadOnlySet(
738749

739750
string emitCudaKernel(
740751
const std::string& specializedName,
741-
const MappedScop& mscop) {
752+
const MappedScop& mscop,
753+
const std::vector<TensorInfo>& inputsInfo) {
742754
// Expecting a schedule with domain root and context first child.
743755
CHECK(mscop.schedule()->elemAs<detail::ScheduleTreeElemDomain>());
744756
CHECK(
@@ -755,7 +767,7 @@ string emitCudaKernel(
755767
emitKernelSignature(ss, specializedName, scop);
756768
emitThreadIdInit(ss, mscop);
757769
emitTensorViews(ss, scop.halide.outputs, paramValues);
758-
emitTensorViews(ss, scop.halide.inputs, paramValues);
770+
emitTensorViews(ss, scop.halide.inputs, paramValues, inputsInfo);
759771
emitTmpDecl(ss, scop);
760772
emitPromotedArrayViewsHalide(ss, scop);
761773
NodeInfoMapType nodeInfoMap;

tc/core/polyhedral/cuda/codegen.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@ struct CodegenStatementContext : CodegenContext {
145145

146146
std::string emitCudaKernel(
147147
const std::string& specializedName,
148-
const MappedScop& scop);
148+
const MappedScop& scop,
149+
const std::vector<TensorInfo>& inputsInfo = std::vector<TensorInfo>{});
149150

150151
} // namespace polyhedral
151152
} // namespace tc

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -910,7 +910,8 @@ std::unique_ptr<MappedScop> makeSpecializedMappedScop(
910910
// the context of the original scop as top-level
911911
// context node in schedule tree.
912912
std::tuple<std::string, tc::Grid, tc::Block> MappedScop::codegen(
913-
const std::string& specializedName) const {
913+
const std::string& specializedName,
914+
const std::vector<TensorInfo>& inputsInfo) const {
914915
validate(schedule());
915916

916917
auto mappedScopForCodegen = makeSpecializedMappedScop(*this);
@@ -927,8 +928,8 @@ std::tuple<std::string, tc::Grid, tc::Block> MappedScop::codegen(
927928
code << code::cuda::cubBlockReduce;
928929
}
929930
code << "extern \"C\" {" << std::endl
930-
<< emitCudaKernel(specializedName, *mappedScopForCodegen) << "}"
931-
<< std::endl;
931+
<< emitCudaKernel(specializedName, *mappedScopForCodegen, inputsInfo)
932+
<< "}" << std::endl;
932933

933934
return std::make_tuple(
934935
code.str(),

tc/core/polyhedral/cuda/mapped_scop.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,9 @@ class MappedScop {
115115
// Generate CUDA code at the current state of transformation provided a
116116
// name for the generated function.
117117
std::tuple<std::string, tc::Grid, tc::Block> codegen(
118-
const std::string& specializedName) const;
118+
const std::string& specializedName,
119+
const std::vector<TensorInfo>& inputsInfo =
120+
std::vector<TensorInfo>{}) const;
119121

120122
// Accessors..
121123
// Const accessor to schedule of underlying Scop.

test/cuda/test_tc_mapper.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,8 +326,8 @@ def tensoraddstrided(float(N, M) I0_view, float(N, M) I1_view) -> (O) {
326326
auto res = Check(TC, name, options, inputs, checkFun);
327327
// This test should be modified when strided tensors are handled
328328
std::string expected =
329-
"const float32 (*I0_view)[64] = "
330-
"reinterpret_cast<const float32 (*)[64]>(pI0_view)";
329+
"const float32 (*I0_view)[128] = "
330+
"reinterpret_cast<const float32 (*)[128]>(pI0_view)";
331331
ASSERT_NE(std::string::npos, res.second.find(expected))
332332
<< "In resulting code:\n"
333333
<< res.second << "\nfound unexpected: " << expected;

test_tc_mapper_output.txt

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
Note: Google Test filter = *Strided*
2+
[==========] Running 1 test from 1 test case.
3+
[----------] Global test environment set-up.
4+
[----------] 1 test from TcCudaMapperTest
5+
[ RUN ] TcCudaMapperTest.TensorAddStrided
6+
WARNING:
7+
Reduction without initialization. If O is not pre-initialized before calling the TC function, consider using the !-suffixed reduction operator +=! instead of +=:
8+
9+
def tensoraddstrided(float(N, M) I0_view, float(N, M) I1_view) -> (O) {
10+
O(n, m) += I0_view(n, m) + I1_view(n, m)
11+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~... <--- HERE
12+
}
13+
14+
WARNING:
15+
Reduction without initialization. If O is not pre-initialized before calling the TC function, consider using the !-suffixed reduction operator +=! instead of +=:
16+
17+
def tensoraddstrided(float(N, M) I0_view, float(N, M) I1_view) -> (O) {
18+
O(n, m) += I0_view(n, m) + I1_view(n, m)
19+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~... <--- HERE
20+
}
21+
22+
I0607 13:02:54.070823 21973 cuda_tc_executor.cc:82] tc::CudaMappingOptions::makeNaiveMappingOptions()
23+
.outerScheduleFusionStrategy(tc::FusionStrategy::Preserve3Coincident)
24+
.outerScheduleAllowSkewing(false)
25+
.outerSchedulePositiveOrthant(true)
26+
.intraTileScheduleFusionStrategy(tc::FusionStrategy::Preserve3Coincident)
27+
.intraTileScheduleAllowSkewing(false)
28+
.intraTileSchedulePositiveOrthant(true)
29+
.fixParametersBeforeScheduling(false)
30+
.tile(32, 32, 32)
31+
.unroll(1)
32+
.tileImperfectlyNested(false)
33+
.matchLibraryCalls(false)
34+
.mapToThreads(32, 8)
35+
.mapToBlocks(256, 256)
36+
.useSharedMemory(false)
37+
.usePrivateMemory(false)
38+
.unrollCopyShared(false)
39+
.useReadOnlyCache(false);
40+
I0607 13:02:54.072165 21973 cuda_tc_executor.cc:83] original schedule:
41+
domain(
42+
[M, N] -> { S_0[O_s1_n, O_s1_m] : 0 <= O_s1_n < N and 0 <= O_s1_m < M })
43+
band(n(1) permutable(0) coincident(0) unroll(0)
44+
-----------------------------------------------------------------------
45+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_n)] }
46+
-----------------------------------------------------------------------
47+
band(n(1) permutable(0) coincident(0) unroll(0)
48+
-----------------------------------------------------------------------
49+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_m)] }
50+
-----------------------------------------------------------------------
51+
I0607 13:02:54.075304 21973 scop.cc:400] After scheduling:
52+
domain(
53+
[M, N] -> { S_0[O_s1_n, O_s1_m] : 0 <= O_s1_n < N and 0 <= O_s1_m < M })
54+
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
55+
-----------------------------------------------------------------------
56+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_n)] }
57+
-----------------------------------------------------------------------
58+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_m)] }
59+
-----------------------------------------------------------------------
60+
I0607 13:02:54.075870 21973 scop.cc:454] After tiling outer:
61+
domain(
62+
[M, N] -> { S_0[O_s1_n, O_s1_m] : 0 <= O_s1_n < N and 0 <= O_s1_m < M })
63+
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
64+
-----------------------------------------------------------------------
65+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(floor((O_s1_n)/32))] }
66+
-----------------------------------------------------------------------
67+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(floor((O_s1_m)/32))] }
68+
-----------------------------------------------------------------------
69+
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
70+
-----------------------------------------------------------------------
71+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_n - 32*floor((O_s1_n)/32))] }
72+
-----------------------------------------------------------------------
73+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_m - 32*floor((O_s1_m)/32))] }
74+
-----------------------------------------------------------------------
75+
I0607 13:02:54.078128 21973 mapped_scop.cc:1021] After mapping to threads:
76+
domain(
77+
[M, N] -> { S_0[O_s1_n, O_s1_m] : M = 64 and N = 64 and 0 <= O_s1_n <= 63 and 0 <= O_s1_m <= 63 })
78+
context([M, N, t1, t0, t2, b2, b1, b0] -> { [] : t2 = 0 and b2 = 0 and 0 <= t1 <= 7 and 0 <= t0 <= 31 and 0 <= b1 <= 255 and 0 <= b0 <= 255 })
79+
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
80+
-----------------------------------------------------------------------
81+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(floor((O_s1_n)/32))] }
82+
-----------------------------------------------------------------------
83+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(floor((O_s1_m)/32))] }
84+
-----------------------------------------------------------------------
85+
mapping_filter(ids(t1, t0, )
86+
[M, N, t0, t1] -> { S_0[O_s1_n, O_s1_m] : (-t1 + O_s1_n) mod 8 = 0 and (-t0 + O_s1_m) mod 32 = 0 and 0 <= t0 <= 31 and 0 <= t1 <= 7 })
87+
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
88+
-----------------------------------------------------------------------
89+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_n - 32*floor((O_s1_n)/32))] }
90+
-----------------------------------------------------------------------
91+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_m - 32*floor((O_s1_m)/32))] }
92+
-----------------------------------------------------------------------
93+
thread_specific()
94+
I0607 13:02:54.079393 21973 schedule_transforms.cc:391] Resizing scales to 2 entries: 32 32 32
95+
I0607 13:02:54.079439 21973 mapped_scop.cc:1029] After mapping to blocks:
96+
domain(
97+
[M, N] -> { S_0[O_s1_n, O_s1_m] : M = 64 and N = 64 and 0 <= O_s1_n <= 63 and 0 <= O_s1_m <= 63 })
98+
context([M, N, t1, t0, t2, b2, b1, b0] -> { [] : t2 = 0 and b2 = 0 and 0 <= t1 <= 7 and 0 <= t0 <= 31 and 0 <= b1 <= 255 and 0 <= b0 <= 255 })
99+
mapping_filter(ids(b1, b0, )
100+
[M, N, b0, b1] -> { S_0[O_s1_n, O_s1_m] : -31 - 32b1 + O_s1_m <= 8192*floor((O_s1_m)/8192) <= -32b1 + O_s1_m and -31 - 32b0 + O_s1_n <= 8192*floor((O_s1_n)/8192) <= -32b0 + O_s1_n })
101+
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
102+
-----------------------------------------------------------------------
103+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(32*floor((O_s1_n)/32))] }
104+
-----------------------------------------------------------------------
105+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(32*floor((O_s1_m)/32))] }
106+
-----------------------------------------------------------------------
107+
mapping_filter(ids(t1, t0, )
108+
[M, N, t0, t1] -> { S_0[O_s1_n, O_s1_m] : (-t1 + O_s1_n) mod 8 = 0 and (-t0 + O_s1_m) mod 32 = 0 and 0 <= t0 <= 31 and 0 <= t1 <= 7 })
109+
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
110+
-----------------------------------------------------------------------
111+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_n - 32*floor((O_s1_n)/32))] }
112+
-----------------------------------------------------------------------
113+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_m - 32*floor((O_s1_m)/32))] }
114+
-----------------------------------------------------------------------
115+
thread_specific()
116+
I0607 13:02:54.079643 21973 mapped_scop.cc:1083] After outerBlockInnerThread strategy:
117+
domain(
118+
[M, N] -> { S_0[O_s1_n, O_s1_m] : M = 64 and N = 64 and 0 <= O_s1_n <= 63 and 0 <= O_s1_m <= 63 })
119+
context([M, N, t1, t0, t2, b2, b1, b0] -> { [] : t2 = 0 and b2 = 0 and 0 <= t1 <= 7 and 0 <= t0 <= 31 and 0 <= b1 <= 255 and 0 <= b0 <= 255 })
120+
mapping_filter(ids(b1, b0, )
121+
[M, N, b0, b1] -> { S_0[O_s1_n, O_s1_m] : -31 - 32b1 + O_s1_m <= 8192*floor((O_s1_m)/8192) <= -32b1 + O_s1_m and -31 - 32b0 + O_s1_n <= 8192*floor((O_s1_n)/8192) <= -32b0 + O_s1_n })
122+
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
123+
-----------------------------------------------------------------------
124+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(32*floor((O_s1_n)/32))] }
125+
-----------------------------------------------------------------------
126+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(32*floor((O_s1_m)/32))] }
127+
-----------------------------------------------------------------------
128+
mapping_filter(ids(t1, t0, )
129+
[M, N, t0, t1] -> { S_0[O_s1_n, O_s1_m] : (-t1 + O_s1_n) mod 8 = 0 and (-t0 + O_s1_m) mod 32 = 0 and 0 <= t0 <= 31 and 0 <= t1 <= 7 })
130+
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
131+
-----------------------------------------------------------------------
132+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_n - 32*floor((O_s1_n)/32))] }
133+
-----------------------------------------------------------------------
134+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_m - 32*floor((O_s1_m)/32))] }
135+
-----------------------------------------------------------------------
136+
thread_specific()
137+
I0607 13:02:54.079829 21973 cuda_tc_executor.cc:90] Mapped schedule:
138+
domain(
139+
[M, N] -> { S_0[O_s1_n, O_s1_m] : M = 64 and N = 64 and 0 <= O_s1_n <= 63 and 0 <= O_s1_m <= 63 })
140+
context([M, N, t1, t0, t2, b2, b1, b0] -> { [] : t2 = 0 and b2 = 0 and 0 <= t1 <= 7 and 0 <= t0 <= 31 and 0 <= b1 <= 255 and 0 <= b0 <= 255 })
141+
mapping_filter(ids(b1, b0, )
142+
[M, N, b0, b1] -> { S_0[O_s1_n, O_s1_m] : -31 - 32b1 + O_s1_m <= 8192*floor((O_s1_m)/8192) <= -32b1 + O_s1_m and -31 - 32b0 + O_s1_n <= 8192*floor((O_s1_n)/8192) <= -32b0 + O_s1_n })
143+
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
144+
-----------------------------------------------------------------------
145+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(32*floor((O_s1_n)/32))] }
146+
-----------------------------------------------------------------------
147+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(32*floor((O_s1_m)/32))] }
148+
-----------------------------------------------------------------------
149+
mapping_filter(ids(t1, t0, )
150+
[M, N, t0, t1] -> { S_0[O_s1_n, O_s1_m] : (-t1 + O_s1_n) mod 8 = 0 and (-t0 + O_s1_m) mod 32 = 0 and 0 <= t0 <= 31 and 0 <= t1 <= 7 })
151+
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
152+
-----------------------------------------------------------------------
153+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_n - 32*floor((O_s1_n)/32))] }
154+
-----------------------------------------------------------------------
155+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_m - 32*floor((O_s1_m)/32))] }
156+
-----------------------------------------------------------------------
157+
thread_specific()
158+
I0607 13:02:54.091660 21973 mapped_scop.cc:900] Codegen with tightened bounds [blocks:CudaDim(2, 2, 1) @0x7ffefab63f90, threads:CudaDim(32, 8, 1) @0x7ffefab63fd0] for tree:
159+
domain(
160+
[M, N] -> { S_0[O_s1_n, O_s1_m] : M = 64 and N = 64 and 0 <= O_s1_n <= 63 and 0 <= O_s1_m <= 63 })
161+
context([M, N, t1, t0, t2, b2, b1, b0] -> { [] : M = 64 and N = 64 and t2 = 0 and b2 = 0 and 0 <= t1 <= 7 and 0 <= t0 <= 31 and 0 <= b1 <= 1 and 0 <= b0 <= 1 })
162+
mapping_filter(ids(b1, b0, )
163+
[M, N, b0, b1] -> { S_0[O_s1_n, O_s1_m] : -31 - 32b1 + O_s1_m <= 8192*floor((O_s1_m)/8192) <= -32b1 + O_s1_m and -31 - 32b0 + O_s1_n <= 8192*floor((O_s1_n)/8192) <= -32b0 + O_s1_n })
164+
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
165+
-----------------------------------------------------------------------
166+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(32*floor((O_s1_n)/32))] }
167+
-----------------------------------------------------------------------
168+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(32*floor((O_s1_m)/32))] }
169+
-----------------------------------------------------------------------
170+
mapping_filter(ids(t1, t0, )
171+
[M, N, t0, t1] -> { S_0[O_s1_n, O_s1_m] : (-t1 + O_s1_n) mod 8 = 0 and (-t0 + O_s1_m) mod 32 = 0 and 0 <= t0 <= 31 and 0 <= t1 <= 7 })
172+
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
173+
-----------------------------------------------------------------------
174+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_n - 32*floor((O_s1_n)/32))] }
175+
-----------------------------------------------------------------------
176+
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_m - 32*floor((O_s1_m)/32))] }
177+
-----------------------------------------------------------------------
178+
thread_specific()
179+
I0607 13:02:54.130249 21973 cuda_rtc.cc:58] NVRTC function source:
180+
181+
template<typename T> inline __device__ T floord(T n, T d) {
182+
return n < 0 ? - (-n + d - 1)/d : n / d;
183+
}
184+
#define if_then_else(cond,a,b) ((cond) ? (a) : (b))
185+
186+
// Halide type handling
187+
typedef int int32;
188+
typedef long int64;
189+
typedef float float32;
190+
typedef double float64;
191+
192+
#define inff __int_as_float(0x7f800000)
193+
#define inf __longlong_as_double(0x7ff0000000000000LL)
194+
195+
// Before CUDA 9, syncwarp is a noop since warps are always synchronized.
196+
#if __CUDACC_VER_MAJOR__ < 9
197+
__device__ void __syncwarp(unsigned mask = 0xFFFFFFFF) {}
198+
#endif
199+
200+
extern "C" {
201+
__global__ void tensoraddstrided_64_64(int32 M, int32 N, float32* pO, const float32* pI0_view, const float32* pI1_view) {
202+
int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
203+
int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
204+
float32 (*O)[64] = reinterpret_cast<float32 (*)[64]>(pO);
205+
const float32 (*I0_view)[128] = reinterpret_cast<const float32 (*)[128]>(pI0_view);
206+
const float32 (*I1_view)[128] = reinterpret_cast<const float32 (*)[128]>(pI1_view);
207+
for (int c2 = t1; c2 <= 31; c2 += 8) {
208+
O[(32 * b0 + c2)][(t0 + 32 * b1)] = (O[(32 * b0 + c2)][(t0 + 32 * b1)] + (I0_view[(32 * b0 + c2)][(t0 + 32 * b1)] + I1_view[(32 * b0 + c2)][(t0 + 32 * b1)]));
209+
}
210+
}
211+
}
212+
I0607 13:02:54.348301 21973 cuda_tc_executor.cc:64] [COMPILE] Compiling with host JIT compiler took: 218ms
213+
WARNING:
214+
Reduction without initialization. If O is not pre-initialized before calling the TC function, consider using the !-suffixed reduction operator +=! instead of +=:
215+
216+
def tensoraddstrided(float(N, M) I0_view, float(N, M) I1_view) -> (O) {
217+
O(n, m) += I0_view(n, m) + I1_view(n, m)
218+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~... <--- HERE
219+
}
220+
221+
[ OK ] TcCudaMapperTest.TensorAddStrided (297 ms)
222+
[----------] 1 test from TcCudaMapperTest (297 ms total)
223+
224+
[----------] Global test environment tear-down
225+
[==========] 1 test from 1 test case ran. (298 ms total)
226+
[ PASSED ] 1 test.

0 commit comments

Comments
 (0)