Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 8b9b9a2

Browse files
authored
Merge pull request #443 from protonu/test-strided-tensors
Adding a unit test for stride support.
2 parents df5444e + a301a82 commit 8b9b9a2

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

test/cuda/test_tc_mapper.cc

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,42 @@ def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) {
297297
::benchmarkKernelOptions(TC, name, inputs, options);
298298
}
299299

300+
///////////////////////////////////////////////////////////////////////////////
301+
// TensorAddStrided
302+
// O(n, m) += I0_view(n, m) * I1_view(n, m)
303+
///////////////////////////////////////////////////////////////////////////////
304+
TEST_F(TcCudaMapperTest, TensorAddStrided) {
305+
N = 64;
306+
M = 64;
307+
at::Tensor I0 = at::CUDA(at::kFloat).rand({N, M});
308+
at::Tensor I0_view =
309+
I0.type().tensor().set_(*I0.storage(), 0, {N, M}, {1, 16});
310+
at::Tensor I1 = at::CUDA(at::kFloat).rand({N, M});
311+
at::Tensor I1_view =
312+
I1.type().tensor().set_(*I1.storage(), 0, {N, M}, {1, 16});
313+
std::vector<at::Tensor> inputs = {I0_view, I1_view};
314+
315+
static constexpr auto TC = R"TC(
316+
def tensoraddstrided(float(N, M) I0_view, float(N, M) I1_view) -> (O) {
317+
O(n, m) += I0_view(n, m) + I1_view(n, m)
318+
}
319+
)TC";
320+
321+
auto checkFun = [](const std::vector<at::Tensor>& ins,
322+
std::vector<at::Tensor>& outs) { return true; };
323+
auto options = tc::CudaMappingOptions::makeNaiveMappingOptions();
324+
auto name = "tensoraddstrided";
325+
auto res = Check(TC, name, options, inputs, checkFun);
326+
// This test should be modified when strided tensors are handled
327+
std::string expected =
328+
"const float32 (*I0_view)[64] = "
329+
"reinterpret_cast<const float32 (*)[64]>(pI0_view)";
330+
331+
ASSERT_NE(std::string::npos, res.second.find(expected))
332+
<< "In resulting code:\n"
333+
<< res.second << "\nfound unexpected: " << expected;
334+
}
335+
300336
///////////////////////////////////////////////////////////////////////////////
301337
// Lookup Table
302338
// O(b, n) +=! LUT(I(b, n), r_r)

0 commit comments

Comments
 (0)