
Commit 6a38863

nicolasvasilache authored and ftynse committed
Make Halide output standard types
Halide has its own way of pretty-printing types. In the case of `bool` (i.e. `uint1`), this conflicts with the builtin CUDA types, and it would likewise conflict for `(u)int2` and `(u)int4`. This commit makes our `Halide::IRPrinter` print `bool` and the other scalar types in a CUDA-compatible way. As a consequence, the typedefs in `cuda_libraries.h` can be removed.
1 parent 012e970 commit 6a38863
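For context, a minimal sketch of the clash (the struct shapes below approximate CUDA's builtin vector types from vector_types.h; the typedef is a hypothetical workaround, not code from this repository):

// CUDA headers already declare small vector types under these names, roughly:
struct uint1 { unsigned int x; };
struct int2 { int x, y; };
struct int4 { int x, y, z, w; };
// Halide's IRPrinter spells bool as "uint1", so emitted code such as
//   uint1 flag;
// picks up the struct above instead of a scalar, and a workaround like
//   typedef unsigned char uint1;  // hypothetical
// does not compile either, since it redefines the CUDA type.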

6 files changed: +103 -72 lines changed

tc/core/cuda/cuda_libraries.h

Lines changed: 0 additions & 13 deletions
@@ -34,19 +34,6 @@ constexpr auto types = R"C(
 // Can't include system dependencies with NVRTC
 // Can't include cuda_fp16.h with NVRTC due to transitive system dependencies
 // #include <cuda_fp16.h>
-
-// Halide type handling
-typedef char int8;
-typedef short int16;
-typedef int int32;
-typedef long int64;
-typedef unsigned char uint8;
-typedef unsigned short uint16;
-typedef unsigned int uint32;
-typedef unsigned long uint64;
-// typedef half float16;
-typedef float float32;
-typedef double float64;
 )C";
 
 constexpr auto defines = R"C(

tc/core/polyhedral/cuda/codegen.cc

Lines changed: 54 additions & 8 deletions
@@ -39,6 +39,37 @@ namespace polyhedral {
 
 namespace {
 
+static std::string halideTypeString(const Halide::Type& t) {
+  if (t.is_bool()) {
+    return "bool";
+  } else if (t.is_int() && t.bits() == 8) {
+    return "char";
+  } else if (t.is_int() && t.bits() == 16) {
+    return "short";
+  } else if (t.is_int() && t.bits() == 32) {
+    return "int";
+  } else if (t.is_int() && t.bits() == 64) {
+    return "long";
+  } else if (t.is_uint() && t.bits() == 8) {
+    return "unsigned char";
+  } else if (t.is_uint() && t.bits() == 16) {
+    return "unsigned short";
+  } else if (t.is_uint() && t.bits() == 32) {
+    return "unsigned int";
+  } else if (t.is_uint() && t.bits() == 64) {
+    return "unsigned long";
+  } else if (t.is_float() && t.bits() == 16) {
+    return "half";
+  } else if (t.is_float() && t.bits() == 32) {
+    return "float";
+  } else if (t.is_float() && t.bits() == 64) {
+    return "double";
+  }
+  std::stringstream ss;
+  ss << t;
+  return ss.str();
+}
+
 struct WS {
   static thread_local int n;
   WS() {
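As a quick sanity sketch of the mapping (hypothetical checks using Halide's public type constructors; not part of this commit):

#include <cassert>
// halideTypeString maps Halide scalar types to builtin C/CUDA spellings:
assert(halideTypeString(Halide::Bool()) == "bool");      // Halide prints uint1
assert(halideTypeString(Halide::Int(64)) == "long");     // Halide prints int64
assert(halideTypeString(Halide::Float(32)) == "float");  // Halide prints float32
// Anything unmatched (e.g. vector types) falls through to Halide's
// own printer via `ss << t`.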
@@ -102,7 +133,7 @@ vector<string> emitParams(const Scop& scop) {
   // Halide params. One of these two vectors will be empty.
   for (auto p : scop.halide.params) {
     stringstream ss;
-    ss << p.type() << " " << p.name();
+    ss << halideTypeString(p.type()) << " " << p.name();
     res.push_back(ss.str());
   }
   return res;
@@ -113,7 +144,7 @@ string emitTypedTensorName(
     Halide::OutputImageParam t,
     bool constInput = false) {
   stringstream ss;
-  ss << (constInput ? "const " : "") << t.type() << "* "
+  ss << (constInput ? "const " : "") << halideTypeString(t.type()) << "* "
      << makePointerName(t.name());
   return ss.str();
 }
@@ -195,11 +226,11 @@ void emitTensorView(
     ssViewType << "[" << extent << "]";
   }
   ss << ws.tab();
-  ss << (constInput ? "const " : "") << p.type() << " (*" << p.name() << ")"
-     << ssViewType.str();
+  ss << (constInput ? "const " : "") << halideTypeString(p.type()) << " (*"
+     << p.name() << ")" << ssViewType.str();
   ss << " = ";
-  ss << "reinterpret_cast<" << (constInput ? "const " : "") << p.type()
-     << " (*)" << ssViewType.str() << ">";
+  ss << "reinterpret_cast<" << (constInput ? "const " : "")
+     << halideTypeString(p.type()) << " (*)" << ssViewType.str() << ">";
   ss << "(" << makePointerName(p.name()) << ")";
   ss << ";";
   ss << endl;
@@ -604,6 +635,21 @@ void emitHalideExpr(
       IRPrinter::visit(op);
     }
   }
+  void visit(const Halide::Internal::IntImm* op) {
+    context.ss << "(" << halideTypeString(op->type) << ")" << op->value;
+  }
+  void visit(const Halide::Internal::UIntImm* op) {
+    context.ss << "(" << halideTypeString(op->type) << ")" << op->value;
+  }
+  void visit(const Halide::Internal::FloatImm* op) {
+    context.ss << "(" << halideTypeString(op->type) << ")" << op->value;
+  }
+  void visit(const Halide::Internal::Cast* op) {
+    context.ss << "(" << halideTypeString(op->type) << ")";
+    context.ss << "(";
+    op->value.accept(this);
+    context.ss << ")";
+  }
   // TODO: handle casts
   const CodegenStatementContext& context;
   const map<string, string>& substitutions;
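These overrides change how scalar immediates and casts are rendered; the updated test expectations further down show the effect. Roughly, for a float accumulator initialization and a cast of a loop index:

// Halide's default printing (old):
//   O[c0][c1] = 0.000000f;
//   O[c0][c1] = A[c0][c1] + float32(c0);
// With the visit() overrides above (new):
//   O[c0][c1] = (float)0.000000;
//   O[c0][c1] = A[c0][c1] + (float)(c0);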
@@ -720,7 +766,7 @@ void emitTmpDecl(stringstream& ss, const Scop& scop) {
     auto updateId = kvp.second;
     auto provide =
         scop.halide.statements.at(updateId).as<Halide::Internal::Provide>();
-    ss << provide->values[0].type() << " "
+    ss << halideTypeString(provide->values[0].type()) << " "
        << makeReductionTmpName(updateId, scop) << ";" << endl;
   }
 }
@@ -745,7 +791,7 @@ void emitPromotedArrayViewsHalide(stringstream& ss, const Scop& scop) {
     if (p.second.kind == Scop::PromotedDecl::Kind::SharedMem) {
       ss << "__shared__ ";
     }
-    ss << t << " " << viewName;
+    ss << halideTypeString(t) << " " << viewName;
     for (auto s : p.second.sizes) {
       ss << "[" << s << "]";
     }

test/cuda/test_basic_gpu.cc

Lines changed: 7 additions & 9 deletions
@@ -99,13 +99,15 @@ void loadUnload(const std::string& ptx) {
 
 TEST(BasicGpuTest, Nvrtc) {
   TC_CUDA_RUNTIMEAPI_ENFORCE(cudaFree(0));
-  auto PTX = jitCompile(R"CUDA(
+  auto PTX = jitCompile(
+      R"CUDA(
 extern "C" {
 __global__ void foo(int N)
 {
   assert(N == 1);
 }
-})CUDA", {"-G"});
+})CUDA",
+      {"-G"});
 
   std::string ptx(PTX.data());
   loadUnload(ptx);
@@ -153,15 +155,12 @@ namespace {
 // Mark the function argument as __restrict__ depending on the flag.
 std::string makeFuncWithOptionalRestrict(bool useRestrict) {
   std::stringstream ss;
-  ss << R"CUDA(typedef float float32;
-extern "C" {
-)CUDA";
   ss
-      << (useRestrict ? "__global__ void func(float32* __restrict__ pO2) {"
-                      : "__global__ void func(float32* pO2) {");
+      << (useRestrict ? "__global__ void func(float* __restrict__ pO2) {"
+                      : "__global__ void func(float* pO2) {");
   ss << R"CUDA(int b0 = blockIdx.x;
 int t0 = threadIdx.x;
-float32 (*O2)[2] = reinterpret_cast<float32 (*)[2]>(pO2);
+float (*O2)[2] = reinterpret_cast<float (*)[2]>(pO2);
 O2[b0][t0] = 0.000000f; // S1
 __syncthreads();
 if (t0 == 0) {
@@ -171,7 +170,6 @@ extern "C" {
 }
 __syncthreads();
 O2[b0][t0] = fmax(O2[b0][t0], 0); // S3
-}
 })CUDA";
   return ss.str();
 }

test/cuda/test_tc_mapper.cc

Lines changed: 2 additions & 2 deletions
@@ -327,8 +327,8 @@ def tensoraddstrided(float(N, M) I0_view, float(N, M) I1_view) -> (O) {
   auto res = Check(TC, name, options, inputs, checkFun);
   // This test should be modified when strided tensors are handled
   std::string expected =
-      "const float32 (*I0_view)[64] = "
-      "reinterpret_cast<const float32 (*)[64]>(pI0_view)";
+      "const float (*I0_view)[64] = "
+      "reinterpret_cast<const float (*)[64]>(pI0_view)";
   ASSERT_NE(std::string::npos, res.second.find(expected))
       << "In resulting code:\n"
      << res.second << "\nfound unexpected: " << expected;

test/test_cuda_mapper.cc

Lines changed: 24 additions & 24 deletions
@@ -365,9 +365,9 @@ def fun(float(N, M) A, float(N, M) B) -> (C) {
   std::string expected(
       R"RES(int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
 int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
-float32 (*C)[M] = reinterpret_cast<float32 (*)[M]>(pC);
-const float32 (*A)[M] = reinterpret_cast<const float32 (*)[M]>(pA);
-const float32 (*B)[M] = reinterpret_cast<const float32 (*)[M]>(pB);
+float (*C)[M] = reinterpret_cast<float (*)[M]>(pC);
+const float (*A)[M] = reinterpret_cast<const float (*)[M]>(pA);
+const float (*B)[M] = reinterpret_cast<const float (*)[M]>(pB);
 for (int c1 = 16 * b1; c1 < M; c1 += 4096) {
   if (M >= t0 + c1 + 1) {
     C[(t1 + 16 * b0)][(t0 + c1)] = (A[(t1 + 16 * b0)][(t0 + c1)] + B[(t1 + 16 * b0)][(t0 + c1)]);
@@ -400,16 +400,16 @@ def fun(float(N, N, N, N) A, float(N, N) B, float(N, N) C, float(N, N) D)
   std::string expected(
       R"RES(int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
 int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
-float32 (*O1)[N] = reinterpret_cast<float32 (*)[N]>(pO1);
-float32 (*O2)[N] = reinterpret_cast<float32 (*)[N]>(pO2);
-float32 (*O3)[N] = reinterpret_cast<float32 (*)[N]>(pO3);
-const float32 (*A)[N][N][N] = reinterpret_cast<const float32 (*)[N][N][N]>(pA);
-const float32 (*B)[N] = reinterpret_cast<const float32 (*)[N]>(pB);
-const float32 (*C)[N] = reinterpret_cast<const float32 (*)[N]>(pC);
-const float32 (*D)[N] = reinterpret_cast<const float32 (*)[N]>(pD);
+float (*O1)[N] = reinterpret_cast<float (*)[N]>(pO1);
+float (*O2)[N] = reinterpret_cast<float (*)[N]>(pO2);
+float (*O3)[N] = reinterpret_cast<float (*)[N]>(pO3);
+const float (*A)[N][N][N] = reinterpret_cast<const float (*)[N][N][N]>(pA);
+const float (*B)[N] = reinterpret_cast<const float (*)[N]>(pB);
+const float (*C)[N] = reinterpret_cast<const float (*)[N]>(pC);
+const float (*D)[N] = reinterpret_cast<const float (*)[N]>(pD);
 for (int c0 = 0; c0 < N; c0 += 1) {
   for (int c1 = 0; c1 < N; c1 += 1) {
-    O1[c0][c1] = 0.000000f;
+    O1[c0][c1] = (float)0.000000;
   }
 }
 for (int c0 = 0; c0 < N; c0 += 1) {
@@ -449,14 +449,14 @@ def fun(float(N, N) A) -> (O)
   auto res = std::get<0>(mscop->codegen(specializedName));
 
   string expected(
-      R"RES(__global__ void kernel_anon(int32 N, float32* pO, const float32* pA) {
+      R"RES(__global__ void kernel_anon(int N, float* pO, const float* pA) {
 int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
 int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
-float32 (*O)[N] = reinterpret_cast<float32 (*)[N]>(pO);
-const float32 (*A)[N] = reinterpret_cast<const float32 (*)[N]>(pA);
+float (*O)[N] = reinterpret_cast<float (*)[N]>(pO);
+const float (*A)[N] = reinterpret_cast<const float (*)[N]>(pA);
 for (int c0 = 0; c0 < N; c0 += 1) {
   for (int c1 = 0; c1 < N; c1 += 1) {
-    O[c0][c1] = (((A[c0][c1] + float32(c0)) + float32(c1)) + float32(N));
+    O[c0][c1] = (((A[c0][c1] + (float)(c0)) + (float)(c1)) + (float)(N));
   }
 }
 }
@@ -478,13 +478,13 @@ def fun(float(N, N) A, float(N, N) B, float(N) C) -> (O)
   auto res = std::get<0>(mscop->codegen(specializedName));
 
   string expected =
-      R"RES(__global__ void kernel_anon(int32 N, float32* pO, const float32* pA, const float32* pB, const float32* pC) {
+      R"RES(__global__ void kernel_anon(int N, float* pO, const float* pA, const float* pB, const float* pC) {
 int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
 int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
-float32 (*O)[512] = reinterpret_cast<float32 (*)[512]>(pO);
-const float32 (*A)[512] = reinterpret_cast<const float32 (*)[512]>(pA);
-const float32 (*B)[512] = reinterpret_cast<const float32 (*)[512]>(pB);
-const float32 (*C) = reinterpret_cast<const float32 (*)>(pC);
+float (*O)[512] = reinterpret_cast<float (*)[512]>(pO);
+const float (*A)[512] = reinterpret_cast<const float (*)[512]>(pA);
+const float (*B)[512] = reinterpret_cast<const float (*)[512]>(pB);
+const float (*C) = reinterpret_cast<const float (*)>(pC);
 for (int c0 = 0; c0 <= 511; c0 += 1) {
   for (int c1 = 0; c1 <= 511; c1 += 1) {
     O[c0][c1] = (nextafter(C[c0], exp(A[c0][c1])) + log(B[c1][c0]));
@@ -499,13 +499,13 @@ def fun(float(N, N) A, float(N, N) B, float(N) C) -> (O)
 constexpr auto kExpectedMatmul_64_64_64 =
     R"CUDA(int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
 int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
-float32 (*O)[64] = reinterpret_cast<float32 (*)[64]>(pO);
-const float32 (*A)[64] = reinterpret_cast<const float32 (*)[64]>(pA);
-const float32 (*B)[64] = reinterpret_cast<const float32 (*)[64]>(pB);
+float (*O)[64] = reinterpret_cast<float (*)[64]>(pO);
+const float (*A)[64] = reinterpret_cast<const float (*)[64]>(pA);
+const float (*B)[64] = reinterpret_cast<const float (*)[64]>(pB);
 for (int c0 = 0; c0 <= 63; c0 += 16) {
   for (int c1 = 0; c1 <= 63; c1 += 16) {
     for (int c2 = t1; c2 <= 15; c2 += 8) {
-      O[(c0 + c2)][(t0 + c1)] = 0.000000f;
+      O[(c0 + c2)][(t0 + c1)] = (float)0.000000;
       for (int c4 = 0; c4 <= 63; c4 += 1) {
         O[(c0 + c2)][(t0 + c1)] = (O[(c0 + c2)][(t0 + c1)] + (A[(c0 + c2)][c4]*B[c4][(t0 + c1)]));
       }

test/test_cuda_mapper_memory_promotion.cc

Lines changed: 16 additions & 16 deletions
@@ -113,9 +113,9 @@ def fun(float(N,M,K,L) A, float(N,M,K,L) B) -> (C) {
 };
 
 TEST_F(Sum4D, CodeOuterBand) {
-  auto declarations = {"__shared__ float32 _A_0[16][16][16][16];",
-                       "__shared__ float32 _B_0[16][16][16][16];",
-                       "__shared__ float32 _C_0[16][16][16][16];"};
+  auto declarations = {"__shared__ float _A_0[16][16][16][16];",
+                       "__shared__ float _B_0[16][16][16][16];",
+                       "__shared__ float _C_0[16][16][16][16];"};
 
   auto copyA =
       "_A_0[c4][c5][c6][c7] = A[16 * b0 + c4][16 * b1 + c5][c2 + c6][c3 + c7];";
@@ -161,9 +161,9 @@ TEST_F(Sum4D, CodeOuterBand) {
  * promoteEverythingAt does not call mapCopiesToThreads.
  */
 TEST_F(Sum4D, CodeAboveThreadMapping) {
-  auto declarations = {"__shared__ float32 _A_0[16][16][16][16];",
-                       "__shared__ float32 _B_0[16][16][16][16];",
-                       "__shared__ float32 _C_0[16][16][16][16];"};
+  auto declarations = {"__shared__ float _A_0[16][16][16][16];",
+                       "__shared__ float _B_0[16][16][16][16];",
+                       "__shared__ float _C_0[16][16][16][16];"};
   auto copyA =
       "_A_0[c4][c5][c6][c7] = A[16 * b0 + c4][16 * b1 + c5][c2 + c6][c3 + c7]";
   auto copyB =
@@ -204,9 +204,9 @@ TEST_F(Sum4D, CodeAboveThreadMapping) {
 }
 
 TEST_F(Sum4D, CodeInnerBand) {
-  auto declarations = {"__shared__ float32 _C_0[1][1][1][1];",
-                       "__shared__ float32 _A_0[1][1][1][1];",
-                       "__shared__ float32 _B_0[1][1][1][1];"};
+  auto declarations = {"__shared__ float _C_0[1][1][1][1];",
+                       "__shared__ float _A_0[1][1][1][1];",
+                       "__shared__ float _B_0[1][1][1][1];"};
   auto copyA =
       "_A_0[0][0][0][0] = A[16 * b0 + c4][16 * b1 + c5][c2 + c6][t0 + c3];";
   auto copyB =
@@ -473,9 +473,9 @@ def fun(float(N,K) A, float(K,M) B, float(N,M) C) -> (O) {
 }
 
 void expectNoABCPromotion(const std::string& code) {
-  auto aDeclPos = code.find(" float32 _A_0");
-  auto bDeclPos = code.find(" float32 _B_0");
-  auto cDeclPos = code.find(" float32 _C_0");
+  auto aDeclPos = code.find(" float _A_0");
+  auto bDeclPos = code.find(" float _B_0");
+  auto cDeclPos = code.find(" float _C_0");
   EXPECT_TRUE(aDeclPos == std::string::npos)
       << "tensor A promoted to register but has elements accessed "
       << "by multiple threads";
@@ -487,7 +487,7 @@ def fun(float(N,K) A, float(K,M) B, float(N,M) C) -> (O) {
 }
 
 void expectFourOElementsPromoted(const std::string& code) {
-  auto oDeclPos = code.find("float32 _O_0[4][1];");
+  auto oDeclPos = code.find("float _O_0[4][1];");
   EXPECT_TRUE(oDeclPos != std::string::npos)
       << "expected O to be promoted to registers";
 
@@ -541,7 +541,7 @@ TEST_F(MatMulBias, RegisterPromotion) {
       .usePrivateMemory(true);
 
   auto code = emitCode({{"N", 42}, {"M", 56}, {"K", 37}}, mappingOptions);
-  auto declPos = code.find("float32 _O_0");
+  auto declPos = code.find("float _O_0");
   auto copyToPos =
       code.find("_O_0[0][0] = O[32 * b0 + c3][t0 + 32 * b1]", declPos + 1);
   auto copyFromPos =
@@ -570,7 +570,7 @@ TEST_F(MatMulBias, RegisterPromotionSharedPreference) {
 
   auto code = emitCode({{"N", 42}, {"M", 56}, {"K", 37}}, mappingOptions);
 
-  auto declPos = code.find("float32 _O_0[1][1]");
+  auto declPos = code.find("float _O_0[1][1]");
   EXPECT_TRUE(declPos == std::string::npos)
       << "not expected promotion to register because promoted to shared";
 
@@ -606,7 +606,7 @@ TEST_F(MatMulBias, RegistersAtRootNotEnoughUnroll) {
   auto mscop = prepare({{"N", 42}, {"M", 56}, {"K", 37}}, mappingOptions);
   promoteToRegistersBelow(*mscop, mscop->scop().scheduleRoot());
   auto code = emitCode(mscop);
-  auto oDeclPos = code.find("float32 _O_0;");
+  auto oDeclPos = code.find("float _O_0;");
 
   EXPECT_TRUE(oDeclPos == std::string::npos)
       << "not expected O to be promoted to registers";
