@@ -182,8 +182,8 @@ def fun(float(N, M) A, float(N, M) B) -> (C) {
182
182
R"RES( int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
183
183
int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
184
184
float32 (*C)[M] = reinterpret_cast<float32 (*)[M]>(pC);
185
- float32 (*A)[M] = reinterpret_cast<float32 (*)[M]>(pA);
186
- float32 (*B)[M] = reinterpret_cast<float32 (*)[M]>(pB);
185
+ const float32 (*A)[M] = reinterpret_cast<const float32 (*)[M]>(pA);
186
+ const float32 (*B)[M] = reinterpret_cast<const float32 (*)[M]>(pB);
187
187
for (int c1 = 16 * b1; c1 < M; c1 += 4096) {
188
188
if (M >= t1 + c1 + 1) {
189
189
C[(t0 + 16 * b0)][(t1 + c1)] = (A[(t0 + 16 * b0)][(t1 + c1)] + B[(t0 + 16 * b0)][(t1 + c1)]);
@@ -219,10 +219,10 @@ def fun(float(N, N, N, N) A, float(N, N) B, float(N, N) C, float(N, N) D)
219
219
float32 (*O1)[N] = reinterpret_cast<float32 (*)[N]>(pO1);
220
220
float32 (*O2)[N] = reinterpret_cast<float32 (*)[N]>(pO2);
221
221
float32 (*O3)[N] = reinterpret_cast<float32 (*)[N]>(pO3);
222
- float32 (*A)[N][N][N] = reinterpret_cast<float32 (*)[N][N][N]>(pA);
223
- float32 (*B)[N] = reinterpret_cast<float32 (*)[N]>(pB);
224
- float32 (*C)[N] = reinterpret_cast<float32 (*)[N]>(pC);
225
- float32 (*D)[N] = reinterpret_cast<float32 (*)[N]>(pD);
222
+ const float32 (*A)[N][N][N] = reinterpret_cast<const float32 (*)[N][N][N]>(pA);
223
+ const float32 (*B)[N] = reinterpret_cast<const float32 (*)[N]>(pB);
224
+ const float32 (*C)[N] = reinterpret_cast<const float32 (*)[N]>(pC);
225
+ const float32 (*D)[N] = reinterpret_cast<const float32 (*)[N]>(pD);
226
226
for (int c0 = 0; c0 < N; c0 += 1) {
227
227
for (int c1 = 0; c1 < N; c1 += 1) {
228
228
O1[c0][c1] = 0.000000f;
@@ -261,11 +261,11 @@ def fun(float(N, N) A) -> (O)
261
261
auto res = std::get<0 >(mscop->codegen (specializedName));
262
262
263
263
string expected (
264
- R"RES( __global__ void kernel_anon(int32 N, float32* pO, float32* pA) {
264
+ R"RES( __global__ void kernel_anon(int32 N, float32* pO, const float32* pA) {
265
265
int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
266
266
int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
267
267
float32 (*O)[N] = reinterpret_cast<float32 (*)[N]>(pO);
268
- float32 (*A)[N] = reinterpret_cast<float32 (*)[N]>(pA);
268
+ const float32 (*A)[N] = reinterpret_cast<const float32 (*)[N]>(pA);
269
269
for (int c0 = 0; c0 < N; c0 += 1) {
270
270
for (int c1 = 0; c1 < N; c1 += 1) {
271
271
O[c0][c1] = (((A[c0][c1] + float32(c0)) + float32(c1)) + float32(N));
@@ -290,13 +290,13 @@ def fun(float(N, N) A, float(N, N) B, float(N) C) -> (O)
290
290
auto res = std::get<0 >(mscop->codegen (specializedName));
291
291
292
292
string expected =
293
- R"RES( __global__ void kernel_anon(int32 N, float32* pO, float32* pA, float32* pB, float32* pC) {
293
+ R"RES( __global__ void kernel_anon(int32 N, float32* pO, const float32* pA, const float32* pB, const float32* pC) {
294
294
int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
295
295
int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
296
296
float32 (*O)[512] = reinterpret_cast<float32 (*)[512]>(pO);
297
- float32 (*A)[512] = reinterpret_cast<float32 (*)[512]>(pA);
298
- float32 (*B)[512] = reinterpret_cast<float32 (*)[512]>(pB);
299
- float32 (*C) = reinterpret_cast<float32 (*)>(pC);
297
+ const float32 (*A)[512] = reinterpret_cast<const float32 (*)[512]>(pA);
298
+ const float32 (*B)[512] = reinterpret_cast<const float32 (*)[512]>(pB);
299
+ const float32 (*C) = reinterpret_cast<const float32 (*)>(pC);
300
300
for (int c0 = 0; c0 <= 511; c0 += 1) {
301
301
for (int c1 = 0; c1 <= 511; c1 += 1) {
302
302
O[c0][c1] = (nextafter(C[c0], exp(A[c0][c1])) + log(B[c1][c0]));
@@ -312,8 +312,8 @@ constexpr auto kExpectedMatmul_64_64_64 =
312
312
R"CUDA( int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
313
313
int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
314
314
float32 (*O)[64] = reinterpret_cast<float32 (*)[64]>(pO);
315
- float32 (*A)[64] = reinterpret_cast<float32 (*)[64]>(pA);
316
- float32 (*B)[64] = reinterpret_cast<float32 (*)[64]>(pB);
315
+ const float32 (*A)[64] = reinterpret_cast<const float32 (*)[64]>(pA);
316
+ const float32 (*B)[64] = reinterpret_cast<const float32 (*)[64]>(pB);
317
317
for (int c0 = 0; c0 <= 63; c0 += 16) {
318
318
for (int c1 = 0; c1 <= 63; c1 += 16) {
319
319
for (int c2 = t1; c2 <= 15; c2 += 8) {
0 commit comments