Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 16b5d55

Browse files
committed
Add concat test case; fix a type inconsistency issue in codegen_llvm.
1 parent a933afc commit 16b5d55

File tree

3 files changed

+55
-1
lines changed

3 files changed

+55
-1
lines changed

tc/core/polyhedral/codegen_llvm.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,22 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
243243
CodeGen_X86::visit(call);
244244
}
245245
}
246+
246247
void visit(const Halide::Internal::Variable* op) override {
247248
value = getValue(iteratorMap_->at(op->name));
249+
250+
// Generate code for type casting if necessary.
251+
llvm::Type* ty = llvm_type_of(op->type);
252+
if (value->getType() != ty) {
253+
if (op->type.is_int()) {
254+
value = builder->CreateIntCast(value, ty, true);
255+
} else {
256+
CHECK(false) << "Type inconsistency not handled. "
257+
<< "Variable " << op->name << " is " << op->type
258+
<< ", but its corresponding llvm::Value is "
259+
<< toString(value->getType()) << ".";
260+
}
261+
}
248262
}
249263

250264
public:

tc/core/polyhedral/codegen_llvm.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include <memory>
1919
#include <string>
20+
#include <type_traits>
2021

2122
#include "llvm/IR/LLVMContext.h"
2223
#include "llvm/IR/Module.h"
@@ -27,7 +28,12 @@
2728

2829
namespace tc {
2930

30-
static inline std::string toString(llvm::Value* llvmObject) {
31+
template <
32+
typename T,
33+
typename std::enable_if<
34+
std::is_base_of<llvm::Value, T>::value ||
35+
std::is_base_of<llvm::Type, T>::value>::type* = nullptr>
36+
static inline std::string toString(T* llvmObject) {
3137
std::string output;
3238
llvm::raw_string_ostream rso(output);
3339
llvmObject->print(rso);

test/test_mapper_llvm.cc

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,40 @@ def convolution(float(N,C,H,W) I, float(O,C,KH,KW) W1, float(O) B) -> (tmp, O1)
242242
checkRtol(output - expected, {I, W1, B}, C * KH * KW, 1e-6);
243243
}
244244

245+
TEST(LLVMCodegen, Concat) {
246+
string tc = R"TC(
247+
def concat(float(M, N) A, float(M, N) B) -> (O1) {
248+
O1(n, i, m) = i == 0 ? A(m, n) : B(m, n) where i in 0:2
249+
}
250+
)TC";
251+
auto N = 16;
252+
auto M = 24;
253+
254+
auto ctx = isl::with_exceptions::globalIslCtx();
255+
auto scop = polyhedral::Scop::makeScop(ctx, tc);
256+
auto context = scop->makeContext(
257+
std::unordered_map<std::string, int>{{"N", N}, {"M", M}});
258+
scop = Scop::makeSpecializedScop(*scop, context);
259+
260+
Jit jit;
261+
jit.codegenScop("concat", *scop);
262+
auto fptr = (void (*)(float*, float*, float*))jit.getSymbolAddress("concat");
263+
264+
at::Tensor A = at::CPU(at::kFloat).rand({M, N});
265+
at::Tensor B = at::CPU(at::kFloat).rand({M, N});
266+
at::Tensor O1 = at::CPU(at::kFloat).rand({N, 2, M});
267+
at::Tensor O1c = at::CPU(at::kFloat).rand({N, 2, M});
268+
269+
for (int n = 0; n < N; ++n) {
270+
for (int m = 0; m < M; ++m) {
271+
O1c[n][0][m] = A[m][n];
272+
O1c[n][1][m] = B[m][n];
273+
}
274+
}
275+
fptr(A.data<float>(), B.data<float>(), O1.data<float>());
276+
checkRtol(O1c - O1, {A, B}, N * M);
277+
}
278+
245279
int main(int argc, char** argv) {
246280
::testing::InitGoogleTest(&argc, argv);
247281
::gflags::ParseCommandLineFlags(&argc, &argv, true);

0 commit comments

Comments
 (0)