Skip to content

Commit 88e9941

Browse files
authored
Merge pull request #4354 from imaginationtech/img-rvv-kernel-generator
[RISC-V] Improve RVV kernel generator LMUL usage
2 parents e3508d3 + 4a12cf5 commit 88e9941

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

kernel/riscv64/generate_kernel.py

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -197,13 +197,13 @@ def generate_gemm_kernel_inner_complex( settings, dest, M, N, vlen, a_regs ):
197197
dest.write("ai += {M}*2;")
198198
dest.write()
199199

200-
201-
accumulation_regs = a_regs * N * settings['LMUL_ACC'].value
200+
# for each vector register loaded from matrix A, we require N registers to hold vector-scalar multiply-accumulate results
201+
accumulation_regs = a_regs * N
202202
dest.write("// {a_regs} vector regs to hold A array contents, {accumulation_regs} regs to hold values accumulated over k",
203203
a_regs=a_regs*2, accumulation_regs=accumulation_regs*2
204204
)
205205
pass_regs = (accumulation_regs + a_regs)*2
206-
tmp_regs = 32-pass_regs
206+
tmp_regs = (32 // settings['LMUL_ACC'].value) - pass_regs
207207
if tmp_regs < 2:
208208
raise RuntimeError("Complex kernel would use too many registers!")
209209

@@ -337,10 +337,12 @@ def generate_gemm_kernel( settings, OUTPUT ):
337337

338338
M = settings['M'].value
339339
N = settings['N'].value
340-
vlenmax = int( settings['reg_width_bits'].value / settings['ELEN_PARAM'].value )
340+
vlenmax = int(settings['reg_width_bits'].value * settings['LMUL_ACC'].value /
341+
settings['ELEN_PARAM'].value)
341342
a_regs = max(int(M/vlenmax), 1)
342343

343-
accumulation_regs = a_regs * N * settings['LMUL_ACC'].value
344+
# for each vector register loaded from matrix A, we require N registers to hold vector-scalar multiply-accumulate results
345+
accumulation_regs = a_regs * N
344346
required_regs = accumulation_regs + a_regs
345347
if is_complex:
346348
required_regs = required_regs * 2 + 2
@@ -380,9 +382,9 @@ def generate_gemm_kernel( settings, OUTPUT ):
380382
'''.format(tail_policy=settings['tail_policy'].value))
381383

382384

383-
if required_regs > 32:
384-
raise Exception("{} vector registers needed during accumulation for unrolling {} x {}{} but only 32 are available".format(
385-
required_regs, N, M, (" with wide accumulator" if settings['LMUL_ACC'].value > 1 else '')
385+
if required_regs > (32 // settings['LMUL_ACC'].value):
386+
raise Exception("{} vector registers needed during accumulation for unrolling {} x {}{} but only {} are available".format(
387+
required_regs, N, M, (" with wide accumulator" if settings['LMUL_ACC'].value > 1 else ''), 32 // settings['LMUL_ACC'].value
386388
))
387389

388390
TRMM = (settings['op'].value == 'trmm')
@@ -448,7 +450,8 @@ def generate_gemm_kernel( settings, OUTPUT ):
448450
def generate_M_tails( dest, settings, M, N ):
449451
M_tail = int(M/2)
450452
M_tail_min = settings['M_tail_scalar_from'].value
451-
vlenmax = int( settings['reg_width_bits'].value / settings['ELEN_PARAM'].value )
453+
vlenmax = int(settings['reg_width_bits'].value * settings['LMUL_ACC'].value
454+
/ settings['ELEN_PARAM'].value )
452455
TRMM = (settings['op'].value == 'trmm')
453456
is_complex = settings['complex'].value
454457
generate_gemm_kernel_inner = generate_gemm_kernel_inner_complex if is_complex else generate_gemm_kernel_inner_real
@@ -667,4 +670,4 @@ def OUTPUT(*args, **kwargs):
667670
ERROR("unsupported kernel type {}".format(settings['op']))
668671

669672
if __name__ == "__main__":
670-
main()
673+
main()

0 commit comments

Comments
 (0)