Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 18c911a

Browse files
authored
Merge pull request #17 from facebookresearch/concat-test
Add test case for concat operation
2 parents 8850eb0 + 8cdae25 commit 18c911a

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

test/test_execution_engine.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,24 @@ struct ATenCompilationUnitTest : public ::testing::Test {
4545
}
4646
};
4747

48+
TEST_F(ATenCompilationUnitTest, DISABLED_Concat) {
49+
at::Tensor a = at::CUDA(at::kFloat).rand({32, 16});
50+
at::Tensor b = at::CUDA(at::kFloat).rand({32, 16});
51+
std::vector<at::Tensor> inputs = {a, b};
52+
std::vector<at::Tensor> outputs;
53+
54+
Check(
55+
R"(
56+
def concat(float(M, N) A, float(M, N) B) -> (O1, O2) {
57+
O1(n, i, m) = i == 0 ? A(m, n) : B(m, n) where i in 0:2
58+
}
59+
)",
60+
"concat",
61+
tc::MappingOptions::makeNaiveMappingOptions(),
62+
inputs,
63+
outputs);
64+
}
65+
4866
TEST_F(ATenCompilationUnitTest, Indexing) {
4967
at::Tensor a = at::CUDA(at::kFloat).rand({3, 4});
5068
at::Tensor b = at::CUDA(at::kInt).ones({2});

0 commit comments

Comments
 (0)