Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit a301a82

Browse files
author
Protonu Basu
committed
Adding a unit test for stride support. The current lack of support means this test will show incorrect strides. The check will be modified once suport for strides is added."
Minor edit to unit test for strided memory to remove call to benchmark
1 parent 9f9e74c commit a301a82

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

test/cuda/test_tc_mapper.cc

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,42 @@ def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) {
297297
::benchmarkKernelOptions(TC, name, inputs, options);
298298
}
299299

300+
///////////////////////////////////////////////////////////////////////////////
301+
// TensorAddStrided
302+
// O(n, m) += I0_view(n, m) * I1_view(n, m)
303+
///////////////////////////////////////////////////////////////////////////////
304+
TEST_F(TcCudaMapperTest, TensorAddStrided) {
305+
N = 64;
306+
M = 64;
307+
at::Tensor I0 = at::CUDA(at::kFloat).rand({N, M});
308+
at::Tensor I0_view =
309+
I0.type().tensor().set_(*I0.storage(), 0, {N, M}, {1, 16});
310+
at::Tensor I1 = at::CUDA(at::kFloat).rand({N, M});
311+
at::Tensor I1_view =
312+
I1.type().tensor().set_(*I1.storage(), 0, {N, M}, {1, 16});
313+
std::vector<at::Tensor> inputs = {I0_view, I1_view};
314+
315+
static constexpr auto TC = R"TC(
316+
def tensoraddstrided(float(N, M) I0_view, float(N, M) I1_view) -> (O) {
317+
O(n, m) += I0_view(n, m) + I1_view(n, m)
318+
}
319+
)TC";
320+
321+
auto checkFun = [](const std::vector<at::Tensor>& ins,
322+
std::vector<at::Tensor>& outs) { return true; };
323+
auto options = tc::CudaMappingOptions::makeNaiveMappingOptions();
324+
auto name = "tensoraddstrided";
325+
auto res = Check(TC, name, options, inputs, checkFun);
326+
// This test should be modified when strided tensors are handled
327+
std::string expected =
328+
"const float32 (*I0_view)[64] = "
329+
"reinterpret_cast<const float32 (*)[64]>(pI0_view)";
330+
331+
ASSERT_NE(std::string::npos, res.second.find(expected))
332+
<< "In resulting code:\n"
333+
<< res.second << "\nfound unexpected: " << expected;
334+
}
335+
300336
///////////////////////////////////////////////////////////////////////////////
301337
// Lookup Table
302338
// O(b, n) +=! LUT(I(b, n), r_r)

0 commit comments

Comments
 (0)