Skip to content

Commit 4a12cf5

Browse files
committed
[RISC-V] Improve RVV kernel generator LMUL usage
The RVV kernel generation script uses the provided LMUL to increase the number of accumulator registers. Since the effect of the LMUL is to group together the vector registers into larger ones, it actually should be used as a multiplier in the calculation of vlenmax. At the moment, no matter what LMUL is provided, the generated kernels would only set the maximum number of vector elements equal to VLEN/SEW. Commit changes the use of LMUL to properly adjust vlenmax. Note that an increase in LMUL results in a decrease in the number of effective vector registers.
1 parent 62f0f50 commit 4a12cf5

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

kernel/riscv64/generate_kernel.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -197,13 +197,13 @@ def generate_gemm_kernel_inner_complex( settings, dest, M, N, vlen, a_regs ):
197197
dest.write("ai += {M}*2;")
198198
dest.write()
199199

200-
201-
accumulation_regs = a_regs * N * settings['LMUL_ACC'].value
200+
# for each vector register loaded from matrix A, we require N registers to hold vector-scalar multiply-accumulate results
201+
accumulation_regs = a_regs * N
202202
dest.write("// {a_regs} vector regs to hold A array contents, {accumulation_regs} regs to hold values accumulated over k",
203203
a_regs=a_regs*2, accumulation_regs=accumulation_regs*2
204204
)
205205
pass_regs = (accumulation_regs + a_regs)*2
206-
tmp_regs = 32-pass_regs
206+
tmp_regs = (32 // settings['LMUL_ACC'].value) - pass_regs
207207
if tmp_regs < 2:
208208
raise RuntimeError("Complex kernel would use too many registers!")
209209

@@ -337,10 +337,12 @@ def generate_gemm_kernel( settings, OUTPUT ):
337337

338338
M = settings['M'].value
339339
N = settings['N'].value
340-
vlenmax = int( settings['reg_width_bits'].value / settings['ELEN_PARAM'].value )
340+
vlenmax = int(settings['reg_width_bits'].value * settings['LMUL_ACC'].value /
341+
settings['ELEN_PARAM'].value)
341342
a_regs = max(int(M/vlenmax), 1)
342343

343-
accumulation_regs = a_regs * N * settings['LMUL_ACC'].value
344+
# for each vector register loaded from matrix A, we require N registers to hold vector-scalar multiply-accumulate results
345+
accumulation_regs = a_regs * N
344346
required_regs = accumulation_regs + a_regs
345347
if is_complex:
346348
required_regs = required_regs * 2 + 2
@@ -380,9 +382,9 @@ def generate_gemm_kernel( settings, OUTPUT ):
380382
'''.format(tail_policy=settings['tail_policy'].value))
381383

382384

383-
if required_regs > 32:
384-
raise Exception("{} vector registers needed during accumulation for unrolling {} x {}{} but only 32 are available".format(
385-
required_regs, N, M, (" with wide accumulator" if settings['LMUL_ACC'].value > 1 else '')
385+
if required_regs > (32 // settings['LMUL_ACC'].value):
386+
raise Exception("{} vector registers needed during accumulation for unrolling {} x {}{} but only {} are available".format(
387+
required_regs, N, M, (" with wide accumulator" if settings['LMUL_ACC'].value > 1 else ''), 32 // settings['LMUL_ACC'].value
386388
))
387389

388390
TRMM = (settings['op'].value == 'trmm')
@@ -448,7 +450,8 @@ def generate_gemm_kernel( settings, OUTPUT ):
448450
def generate_M_tails( dest, settings, M, N ):
449451
M_tail = int(M/2)
450452
M_tail_min = settings['M_tail_scalar_from'].value
451-
vlenmax = int( settings['reg_width_bits'].value / settings['ELEN_PARAM'].value )
453+
vlenmax = int(settings['reg_width_bits'].value * settings['LMUL_ACC'].value
454+
/ settings['ELEN_PARAM'].value )
452455
TRMM = (settings['op'].value == 'trmm')
453456
is_complex = settings['complex'].value
454457
generate_gemm_kernel_inner = generate_gemm_kernel_inner_complex if is_complex else generate_gemm_kernel_inner_real
@@ -667,4 +670,4 @@ def OUTPUT(*args, **kwargs):
667670
ERROR("unsupported kernel type {}".format(settings['op']))
668671

669672
if __name__ == "__main__":
670-
main()
673+
main()

0 commit comments

Comments
 (0)