Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 6dc1fbc

Browse files
author
Sven Verdoolaege
committed
test
1 parent e3323ed commit 6dc1fbc

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

test/cuda/test_tc_mapper_bugs.cc

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "tc/core/cuda/cuda_tc_executor.h"
2525
#include "tc/core/flags.h"
2626
#include "tc/core/polyhedral/exceptions.h"
27+
#include "tc/library/matmul.h"
2728

2829
#include "test_harness_aten_cuda.h"
2930

@@ -39,6 +40,37 @@ using namespace tc;
3940
// "Bug" suffix and fly away in the sun.
4041
///////////////////////////////////////////////////////////////////////////////
4142

43+
TEST(A, B) {
44+
auto TC = makeMatmulTc();
45+
auto options =
46+
tc::CudaMappingOptions::makeNaiveMappingOptions()
47+
.outerScheduleFusionStrategy(tc::FusionStrategy::Max)
48+
.outerScheduleAllowSkewing(false)
49+
.outerSchedulePositiveOrthant(true)
50+
.intraTileScheduleFusionStrategy(tc::FusionStrategy::Min)
51+
.intraTileScheduleAllowSkewing(false)
52+
.intraTileSchedulePositiveOrthant(true)
53+
.fixParametersBeforeScheduling(false)
54+
.tile(56, 32, 4, 14, 16)
55+
.unroll(16)
56+
.tileImperfectlyNested(false)
57+
.matchLibraryCalls(false)
58+
.mapToThreads(4, 128)
59+
.mapToBlocks(1, 32, 32)
60+
.useSharedMemory(false)
61+
.usePrivateMemory(true)
62+
.unrollCopyShared(false)
63+
.useReadOnlyCache(false);
64+
uint32_t N = 100, K = 400, M = 500;
65+
at::Tensor A = at::CUDA(at::kFloat).rand({N, K});
66+
at::Tensor B = at::CUDA(at::kFloat).rand({K, M});
67+
std::vector<at::Tensor> inputs = {A, B};
68+
auto pExecutor =
69+
tc::aten::compile<tc::CudaBackend>(TC, "matmul", inputs, options);
70+
auto outputs = tc::aten::prepareOutputs(TC, "matmul", inputs);
71+
tc::aten::run(*pExecutor, inputs, outputs);
72+
}
73+
4274
std::string makeUniqueName(const std::string& name) {
4375
static int count = 0;
4476
return name + std::string("_cnt") + std::to_string(++count);

0 commit comments

Comments
 (0)