@@ -297,6 +297,42 @@ def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) {
297
297
::benchmarkKernelOptions (TC, name, inputs, options);
298
298
}
299
299
300
+ // /////////////////////////////////////////////////////////////////////////////
301
+ // TensorAddStrided
302
+ // O(n, m) += I0_view(n, m) * I1_view(n, m)
303
+ // /////////////////////////////////////////////////////////////////////////////
304
+ TEST_F (TcCudaMapperTest, TensorAddStrided) {
305
+ N = 64 ;
306
+ M = 64 ;
307
+ at::Tensor I0 = at::CUDA (at::kFloat ).rand ({N, M});
308
+ at::Tensor I0_view =
309
+ I0.type ().tensor ().set_ (*I0.storage (), 0 , {N, M}, {1 , 16 });
310
+ at::Tensor I1 = at::CUDA (at::kFloat ).rand ({N, M});
311
+ at::Tensor I1_view =
312
+ I1.type ().tensor ().set_ (*I1.storage (), 0 , {N, M}, {1 , 16 });
313
+ std::vector<at::Tensor> inputs = {I0_view, I1_view};
314
+
315
+ static constexpr auto TC = R"TC(
316
+ def tensoraddstrided(float(N, M) I0_view, float(N, M) I1_view) -> (O) {
317
+ O(n, m) += I0_view(n, m) + I1_view(n, m)
318
+ }
319
+ )TC" ;
320
+
321
+ auto checkFun = [](const std::vector<at::Tensor>& ins,
322
+ std::vector<at::Tensor>& outs) { return true ; };
323
+ auto options = tc::CudaMappingOptions::makeNaiveMappingOptions ();
324
+ auto name = " tensoraddstrided" ;
325
+ auto res = Check (TC, name, options, inputs, checkFun);
326
+ // This test should be modified when strided tensors are handled
327
+ std::string expected =
328
+ " const float32 (*I0_view)[64] = "
329
+ " reinterpret_cast<const float32 (*)[64]>(pI0_view)" ;
330
+
331
+ ASSERT_NE (std::string::npos, res.second .find (expected))
332
+ << " In resulting code:\n "
333
+ << res.second << " \n found unexpected: " << expected;
334
+ }
335
+
300
336
// /////////////////////////////////////////////////////////////////////////////
301
337
// Lookup Table
302
338
// O(b, n) +=! LUT(I(b, n), r_r)
0 commit comments