This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit f418bfa

Merge pull request #248 from nicolasvasilache/pr/cuda-input-const-prt

Use const pointers for input tensors in generated cuda

2 parents: aaf96f2 + cec3512
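
The net effect on generated code: input tensors are now passed to CUDA kernels as const-qualified pointers, and the in-kernel array views built over them are const as well, while output tensors stay writable. For example, from the updated expectations in test/test_mapper.cc below:

    before:  __global__ void kernel_anon(int32 N, float32* pO, float32* pA) {
    after:   __global__ void kernel_anon(int32 N, float32* pO, const float32* pA) {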

File tree

4 files changed (+46, -23):

  include/tc/lang/sema.h
  src/core/polyhedral/cuda/codegen.cc
  test/test_mapper.cc
  test/test_tc2halide.cc

include/tc/lang/sema.h

Lines changed: 10 additions & 2 deletions

@@ -155,6 +155,7 @@ static inline TreeRef match_types(TreeRef a, TreeRef b) {
 /// - replace TK_APPLY with TK_BUILT_IN for built in functions
 /// - checks that all variables are defined, and creates index/reduction
 ///   variable objects.
+/// - checks that input variables are readonly.
 struct Sema {
   std::unordered_map<TreeRef, TreeRef> expr_to_type;

@@ -349,10 +350,13 @@ struct Sema {
       }
     }

-    for (auto p : func.params())
+    for (auto p : func.params()) {
       nonTemporaries.insert(p.ident().name());
-    for (auto r : func.returns())
+      inputParameters.insert(p.ident().name());
+    }
+    for (auto r : func.returns()) {
       nonTemporaries.insert(r.ident().name());
+    }

     auto statements_ =
         checkList(func.statements(), [&](TreeRef r) { return checkStmt(r); });

@@ -445,6 +449,9 @@ struct Sema {

     // make dimension variables for each dimension of the output tensor
     std::string name = stmt.ident().name();
+    if (inputParameters.count(name) > 0) {
+      throw ErrorReport(stmt_) << "TC inputs are immutable";
+    }
     TreeList output_indices;
     int n = stmt.indices().size();
     for (int i = 0; i < n; ++i) {

@@ -614,6 +621,7 @@ struct Sema {
   // allowed
   std::unordered_set<std::string> live_input_names;

+  std::unordered_set<std::string> inputParameters;
   std::unordered_set<std::string> nonTemporaries;
 };
 } // namespace lang
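
With this check, any TC statement whose left-hand side names an input parameter is rejected during semantic analysis. A minimal example (the same shape as the new MutableInput test added in test/test_tc2halide.cc below):

    def foo(float(N) A) -> (B) {
      A(i) = A(i) + 42
      B(k) +=! A(i) where k in 0:1
    }

Here the write to A(i) hits the inputParameters lookup in checkStmt and throws an ErrorReport with the message "TC inputs are immutable".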

src/core/polyhedral/cuda/codegen.cc

Lines changed: 13 additions & 7 deletions

@@ -108,9 +108,12 @@ vector<string> emitParams(const Scop& scop) {
 }

 // Returns number of names printed, i.e. tensors.size().
-string emitTypedTensorName(Halide::OutputImageParam t) {
+string emitTypedTensorName(
+    Halide::OutputImageParam t,
+    bool constInput = false) {
   stringstream ss;
-  ss << t.type() << "* " << makePointerName(t.name());
+  ss << (constInput ? "const " : "") << t.type() << "* "
+     << makePointerName(t.name());
   return ss.str();
 }

@@ -128,7 +131,7 @@ vector<string> emitTypedTensorNames(const vector<Halide::ImageParam>& tensors) {
   vector<string> res;
   res.reserve(tensors.size());
   for (auto t : tensors) {
-    res.push_back(emitTypedTensorName(t));
+    res.push_back(emitTypedTensorName(t, true));
   }
   return res;
 }

@@ -179,7 +182,8 @@ void emitKernelSignature(
 void emitTensorView(
     stringstream& ss,
     Halide::OutputImageParam p,
-    const map<string, Halide::Expr>& paramValues) {
+    const map<string, Halide::Expr>& paramValues,
+    bool constInput = false) {
   WS ws;
   stringstream ssViewType;
   for (int i = 1; i < p.dimensions(); ++i) { // Skip the outermost dimension

@@ -190,9 +194,11 @@ void emitTensorView(
     ssViewType << "[" << extent << "]";
   }
   ss << ws.tab();
-  ss << p.type() << " (*" << p.name() << ")" << ssViewType.str();
+  ss << (constInput ? "const " : "") << p.type() << " (*" << p.name() << ")"
+     << ssViewType.str();
   ss << " = ";
-  ss << "reinterpret_cast<" << p.type() << " (*)" << ssViewType.str() << ">";
+  ss << "reinterpret_cast<" << (constInput ? "const " : "") << p.type()
+     << " (*)" << ssViewType.str() << ">";
   ss << "(" << makePointerName(p.name()) << ")";
   ss << ";";
   ss << endl;

@@ -212,7 +218,7 @@ void emitTensorViews(
     const vector<Halide::ImageParam>& params,
     const map<string, Halide::Expr>& paramValues) {
   for (auto p : params) {
-    emitTensorView(ss, p, paramValues);
+    emitTensorView(ss, p, paramValues, true);
   }
 }
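
Taken together, for an input tensor A of type float32 and shape (N, M), the constInput = true paths now emit a const pointer parameter and a const view (names as in the tests, where makePointerName yields pA):

    kernel parameter:  const float32* pA
    tensor view:       const float32 (*A)[M] = reinterpret_cast<const float32 (*)[M]>(pA);

Output tensors go through the same functions with constInput left at its default of false, so their pointers and views remain mutable.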

test/test_mapper.cc

Lines changed: 14 additions & 14 deletions

@@ -182,8 +182,8 @@ def fun(float(N, M) A, float(N, M) B) -> (C) {
       R"RES(int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
 int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
 float32 (*C)[M] = reinterpret_cast<float32 (*)[M]>(pC);
-float32 (*A)[M] = reinterpret_cast<float32 (*)[M]>(pA);
-float32 (*B)[M] = reinterpret_cast<float32 (*)[M]>(pB);
+const float32 (*A)[M] = reinterpret_cast<const float32 (*)[M]>(pA);
+const float32 (*B)[M] = reinterpret_cast<const float32 (*)[M]>(pB);
 for (int c1 = 16 * b1; c1 < M; c1 += 4096) {
   if (M >= t1 + c1 + 1) {
     C[(t0 + 16 * b0)][(t1 + c1)] = (A[(t0 + 16 * b0)][(t1 + c1)] + B[(t0 + 16 * b0)][(t1 + c1)]);

@@ -219,10 +219,10 @@ def fun(float(N, N, N, N) A, float(N, N) B, float(N, N) C, float(N, N) D)
 float32 (*O1)[N] = reinterpret_cast<float32 (*)[N]>(pO1);
 float32 (*O2)[N] = reinterpret_cast<float32 (*)[N]>(pO2);
 float32 (*O3)[N] = reinterpret_cast<float32 (*)[N]>(pO3);
-float32 (*A)[N][N][N] = reinterpret_cast<float32 (*)[N][N][N]>(pA);
-float32 (*B)[N] = reinterpret_cast<float32 (*)[N]>(pB);
-float32 (*C)[N] = reinterpret_cast<float32 (*)[N]>(pC);
-float32 (*D)[N] = reinterpret_cast<float32 (*)[N]>(pD);
+const float32 (*A)[N][N][N] = reinterpret_cast<const float32 (*)[N][N][N]>(pA);
+const float32 (*B)[N] = reinterpret_cast<const float32 (*)[N]>(pB);
+const float32 (*C)[N] = reinterpret_cast<const float32 (*)[N]>(pC);
+const float32 (*D)[N] = reinterpret_cast<const float32 (*)[N]>(pD);
 for (int c0 = 0; c0 < N; c0 += 1) {
   for (int c1 = 0; c1 < N; c1 += 1) {
     O1[c0][c1] = 0.000000f;

@@ -261,11 +261,11 @@ def fun(float(N, N) A) -> (O)
   auto res = std::get<0>(mscop->codegen(specializedName));

   string expected(
-      R"RES(__global__ void kernel_anon(int32 N, float32* pO, float32* pA) {
+      R"RES(__global__ void kernel_anon(int32 N, float32* pO, const float32* pA) {
 int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
 int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
 float32 (*O)[N] = reinterpret_cast<float32 (*)[N]>(pO);
-float32 (*A)[N] = reinterpret_cast<float32 (*)[N]>(pA);
+const float32 (*A)[N] = reinterpret_cast<const float32 (*)[N]>(pA);
 for (int c0 = 0; c0 < N; c0 += 1) {
   for (int c1 = 0; c1 < N; c1 += 1) {
     O[c0][c1] = (((A[c0][c1] + float32(c0)) + float32(c1)) + float32(N));

@@ -290,13 +290,13 @@ def fun(float(N, N) A, float(N, N) B, float(N) C) -> (O)
   auto res = std::get<0>(mscop->codegen(specializedName));

   string expected =
-      R"RES(__global__ void kernel_anon(int32 N, float32* pO, float32* pA, float32* pB, float32* pC) {
+      R"RES(__global__ void kernel_anon(int32 N, float32* pO, const float32* pA, const float32* pB, const float32* pC) {
 int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
 int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
 float32 (*O)[512] = reinterpret_cast<float32 (*)[512]>(pO);
-float32 (*A)[512] = reinterpret_cast<float32 (*)[512]>(pA);
-float32 (*B)[512] = reinterpret_cast<float32 (*)[512]>(pB);
-float32 (*C) = reinterpret_cast<float32 (*)>(pC);
+const float32 (*A)[512] = reinterpret_cast<const float32 (*)[512]>(pA);
+const float32 (*B)[512] = reinterpret_cast<const float32 (*)[512]>(pB);
+const float32 (*C) = reinterpret_cast<const float32 (*)>(pC);
 for (int c0 = 0; c0 <= 511; c0 += 1) {
   for (int c1 = 0; c1 <= 511; c1 += 1) {
     O[c0][c1] = (nextafter(C[c0], exp(A[c0][c1])) + log(B[c1][c0]));

@@ -312,8 +312,8 @@ constexpr auto kExpectedMatmul_64_64_64 =
     R"CUDA(int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
 int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
 float32 (*O)[64] = reinterpret_cast<float32 (*)[64]>(pO);
-float32 (*A)[64] = reinterpret_cast<float32 (*)[64]>(pA);
-float32 (*B)[64] = reinterpret_cast<float32 (*)[64]>(pB);
+const float32 (*A)[64] = reinterpret_cast<const float32 (*)[64]>(pA);
+const float32 (*B)[64] = reinterpret_cast<const float32 (*)[64]>(pB);
 for (int c0 = 0; c0 <= 63; c0 += 16) {
   for (int c1 = 0; c1 <= 63; c1 += 16) {
     for (int c2 = t1; c2 <= 15; c2 += 8) {

test/test_tc2halide.cc

Lines changed: 9 additions & 0 deletions

@@ -197,6 +197,15 @@ def fun(float(M, N) I) -> (O1, O2, O3) {
   Check(tc, {123, 13});
 }

+TEST_F(TC2Isl, MutableInput) {
+  string tc = R"TC(
+def foo(float(N) A) -> (B) {
+  A(i) = A(i) + 42
+  B(k) +=! A(i) where k in 0:1
+}
+)TC";
+  EXPECT_THROW(Check(tc, {123}), ::lang::ErrorReport);
+}
 int main(int argc, char** argv) {
   ::testing::InitGoogleTest(&argc, argv);
   ::gflags::ParseCommandLineFlags(&argc, &argv, true);
