Skip to content

Commit 295d4d5

Browse files
committed
lint fix
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
1 parent 87dbab7 commit 295d4d5

File tree

1 file changed

+69
-68
lines changed

1 file changed

+69
-68
lines changed

csrc/rocm/skinny_gemms.cu

Lines changed: 69 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -275,12 +275,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
275275
wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
276276
const scalar_t* __restrict__ A, scalar_t* C,
277277
const int _WvPrGrp, const int CuCount) {
278-
279-
#if defined(__HIP__MI300__)
278+
#if defined(__HIP__MI300__)
280279
constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
281-
#else
280+
#else
282281
constexpr bool use_mfma = false;
283-
#endif
282+
#endif
284283

285284
using scalar8 =
286285
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
@@ -389,7 +388,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
389388
if (k_ >= K) break;
390389

391390
const scalar_t* B_ = &B[(m + 0) * K + k_];
392-
for (int y=0; y<YTILE; y++)
391+
for (int y = 0; y < YTILE; y++)
393392
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * K])));
394393
}
395394

@@ -418,16 +417,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
418417
#pragma unroll
419418
for (uint32_t n = 0; n < N; n++) {
420419
#pragma unroll
421-
for (int y=0; y<YTILE; y++) {
420+
for (int y = 0; y < YTILE; y++) {
422421
if constexpr (!use_mfma)
423422
#pragma unroll
424-
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
423+
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
425424
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
426-
}
425+
}
427426
else
428427
#pragma unroll
429428
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
430-
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
429+
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
430+
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
431431
}
432432
}
433433
}
@@ -440,23 +440,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
440440
for (int n = 0; n < N; n++) {
441441
for (int y = 0; y < YTILE; y++) {
442442
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
443-
: "=v"(sum[n][y])
444-
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
443+
: "=v"(sum[n][y])
444+
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
445445
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
446-
: "=v"(sum[n][y])
447-
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
446+
: "=v"(sum[n][y])
447+
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
448448
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
449-
: "=v"(sum[n][y])
450-
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
449+
: "=v"(sum[n][y])
450+
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
451451
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
452-
: "=v"(sum[n][y])
453-
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
452+
: "=v"(sum[n][y])
453+
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
454454
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
455-
: "=v"(sum[n][y])
456-
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
455+
: "=v"(sum[n][y])
456+
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
457457
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
458-
: "=v"(sum[n][y])
459-
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
458+
: "=v"(sum[n][y])
459+
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
460460
}
461461
}
462462

@@ -473,9 +473,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
473473
for (int n = 0; n < N; n++) {
474474
#pragma unroll
475475
for (int y = 0; y < YTILE; y++) {
476-
//float accm1 = 0;
477-
//for (int i=0; i<64; i++)
478-
// accm1 += __shfl(sum4[n][y][i%4], i);
476+
// float accm1 = 0;
477+
// for (int i=0; i<64; i++)
478+
// accm1 += __shfl(sum4[n][y][i%4], i);
479479
float accm = sum4[n][y][0];
480480
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
481481
: "=v"(accm)
@@ -535,11 +535,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
535535
wvSplitK_hf_(const int K, const int M, const scalar_t* B,
536536
const scalar_t* __restrict__ A, scalar_t* C,
537537
const int _WvPrGrp, const int CuCount) {
538-
#if defined(__HIP__MI300__)
538+
#if defined(__HIP__MI300__)
539539
constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
540-
#else
540+
#else
541541
constexpr bool use_mfma = false;
542-
#endif
542+
#endif
543543

544544
using scalar8 =
545545
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
@@ -672,7 +672,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
672672
if (k_ >= K) break;
673673

674674
const scalar_t* B_ = &B[(m + 0) * K + k_];
675-
for (int b=0; b<YTILE; b++)
675+
for (int b = 0; b < YTILE; b++)
676676
bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
677677
}
678678

@@ -704,16 +704,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
704704
// Do the matrix multiplication of activation and weight matrix
705705
// - Remember the accumulation is happening for K-split of 64!
706706
#pragma unroll
707-
for (int y=0; y<YTILE; y++) {
707+
for (int y = 0; y < YTILE; y++) {
708708
if constexpr (!use_mfma)
709709
#pragma unroll
710710
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
711711
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
712-
}
712+
}
713713
else
714714
#pragma unroll
715715
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
716-
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
716+
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
717+
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
717718
}
718719
}
719720
}
@@ -726,23 +727,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
726727
for (int n = 0; n < N; n++) {
727728
for (int y = 0; y < YTILE; y++) {
728729
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
729-
: "=v"(sum[n][y])
730-
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
730+
: "=v"(sum[n][y])
731+
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
731732
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
732-
: "=v"(sum[n][y])
733-
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
733+
: "=v"(sum[n][y])
734+
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
734735
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
735-
: "=v"(sum[n][y])
736-
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
736+
: "=v"(sum[n][y])
737+
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
737738
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
738-
: "=v"(sum[n][y])
739-
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
739+
: "=v"(sum[n][y])
740+
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
740741
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
741-
: "=v"(sum[n][y])
742-
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
742+
: "=v"(sum[n][y])
743+
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
743744
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
744-
: "=v"(sum[n][y])
745-
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
745+
: "=v"(sum[n][y])
746+
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
746747
}
747748
}
748749

@@ -759,9 +760,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
759760
for (int n = 0; n < N; n++) {
760761
#pragma unroll
761762
for (int y = 0; y < YTILE; y++) {
762-
//float accm1 = 0;
763-
//for (int i=0; i<64; i++)
764-
// accm1 += __shfl(sum4[n][y][i%4], i);
763+
// float accm1 = 0;
764+
// for (int i=0; i<64; i++)
765+
// accm1 += __shfl(sum4[n][y][i%4], i);
765766

766767
float accm = sum4[n][y][0];
767768
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
@@ -834,11 +835,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
834835
wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
835836
const scalar_t* __restrict__ A, scalar_t* C,
836837
const int _WvPrGrp, const int CuCount) {
837-
#if defined(__HIP__MI300__)
838+
#if defined(__HIP__MI300__)
838839
constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
839-
#else
840+
#else
840841
constexpr bool use_mfma = false;
841-
#endif
842+
#endif
842843

843844
using scalar8 =
844845
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
@@ -961,7 +962,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
961962
if constexpr (!use_mfma)
962963
sum[n][i] = 0;
963964
else
964-
sum4[n][i] = {0,0,0,0};
965+
sum4[n][i] = {0, 0, 0, 0};
965966

966967
bigType bigA[N][UNRL];
967968
bigType bigB[YTILE][UNRL];
@@ -1010,7 +1011,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
10101011
if (k_ >= K) break;
10111012

10121013
const scalar_t* B_ = &B[(m + 0) * K + k_];
1013-
for (int b=0; b<YTILE; b++)
1014+
for (int b = 0; b < YTILE; b++)
10141015
bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
10151016
}
10161017

@@ -1046,16 +1047,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
10461047
// Do the matrix multiplication of activation and weight matrix
10471048
// - Remember the accumulation is happening for K-split of 64!
10481049
#pragma unroll
1049-
for (int y=0; y<YTILE; y++) {
1050+
for (int y = 0; y < YTILE; y++) {
10501051
if constexpr (!use_mfma)
10511052
#pragma unroll
10521053
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
10531054
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
1054-
}
1055-
else
1055+
}
1056+
else
10561057
#pragma unroll
10571058
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
1058-
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
1059+
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
1060+
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
10591061
}
10601062
}
10611063
}
@@ -1076,23 +1078,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
10761078
for (int n = 0; n < N; n++) {
10771079
for (int y = 0; y < YTILE; y++) {
10781080
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
1079-
: "=v"(sum[n][y])
1080-
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
1081+
: "=v"(sum[n][y])
1082+
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
10811083
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
1082-
: "=v"(sum[n][y])
1083-
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
1084+
: "=v"(sum[n][y])
1085+
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
10841086
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
1085-
: "=v"(sum[n][y])
1086-
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
1087+
: "=v"(sum[n][y])
1088+
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
10871089
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
1088-
: "=v"(sum[n][y])
1089-
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
1090+
: "=v"(sum[n][y])
1091+
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
10901092
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
1091-
: "=v"(sum[n][y])
1092-
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
1093+
: "=v"(sum[n][y])
1094+
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
10931095
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
1094-
: "=v"(sum[n][y])
1095-
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
1096+
: "=v"(sum[n][y])
1097+
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
10961098
}
10971099
}
10981100

@@ -1148,7 +1150,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
11481150
}
11491151
}
11501152

1151-
11521153
m += CuCount * _WvPrGrp * YTILE;
11531154
kBase = 0;
11541155

0 commit comments

Comments
 (0)