@@ -197,13 +197,13 @@ def generate_gemm_kernel_inner_complex( settings, dest, M, N, vlen, a_regs ):
197
197
dest .write ("ai += {M}*2;" )
198
198
dest .write ()
199
199
200
-
201
- accumulation_regs = a_regs * N * settings [ 'LMUL_ACC' ]. value
200
+ # for each vector register loaded from matrix A, we require N registers to hold vector-scalar multiply-accumulate results
201
+ accumulation_regs = a_regs * N
202
202
dest .write ("// {a_regs} vector regs to hold A array contents, {accumulation_regs} regs to hold values accumulated over k" ,
203
203
a_regs = a_regs * 2 , accumulation_regs = accumulation_regs * 2
204
204
)
205
205
pass_regs = (accumulation_regs + a_regs )* 2
206
- tmp_regs = 32 - pass_regs
206
+ tmp_regs = ( 32 // settings [ 'LMUL_ACC' ]. value ) - pass_regs
207
207
if tmp_regs < 2 :
208
208
raise RuntimeError ("Complex kernel would use too many registers!" )
209
209
@@ -337,10 +337,12 @@ def generate_gemm_kernel( settings, OUTPUT ):
337
337
338
338
M = settings ['M' ].value
339
339
N = settings ['N' ].value
340
- vlenmax = int ( settings ['reg_width_bits' ].value / settings ['ELEN_PARAM' ].value )
340
+ vlenmax = int (settings ['reg_width_bits' ].value * settings ['LMUL_ACC' ].value /
341
+ settings ['ELEN_PARAM' ].value )
341
342
a_regs = max (int (M / vlenmax ), 1 )
342
343
343
- accumulation_regs = a_regs * N * settings ['LMUL_ACC' ].value
344
+ # for each vector register loaded from matrix A, we require N registers to hold vector-scalar multiply-accumulate results
345
+ accumulation_regs = a_regs * N
344
346
required_regs = accumulation_regs + a_regs
345
347
if is_complex :
346
348
required_regs = required_regs * 2 + 2
@@ -380,9 +382,9 @@ def generate_gemm_kernel( settings, OUTPUT ):
380
382
''' .format (tail_policy = settings ['tail_policy' ].value ))
381
383
382
384
383
- if required_regs > 32 :
384
- raise Exception ("{} vector registers needed during accumulation for unrolling {} x {}{} but only 32 are available" .format (
385
- required_regs , N , M , (" with wide accumulator" if settings ['LMUL_ACC' ].value > 1 else '' )
385
+ if required_regs > ( 32 // settings [ 'LMUL_ACC' ]. value ) :
386
+ raise Exception ("{} vector registers needed during accumulation for unrolling {} x {}{} but only {} are available" .format (
387
+ required_regs , N , M , (" with wide accumulator" if settings ['LMUL_ACC' ].value > 1 else '' ), 32 // settings [ 'LMUL_ACC' ]. value
386
388
))
387
389
388
390
TRMM = (settings ['op' ].value == 'trmm' )
@@ -448,7 +450,8 @@ def generate_gemm_kernel( settings, OUTPUT ):
448
450
def generate_M_tails ( dest , settings , M , N ):
449
451
M_tail = int (M / 2 )
450
452
M_tail_min = settings ['M_tail_scalar_from' ].value
451
- vlenmax = int ( settings ['reg_width_bits' ].value / settings ['ELEN_PARAM' ].value )
453
+ vlenmax = int (settings ['reg_width_bits' ].value * settings ['LMUL_ACC' ].value
454
+ / settings ['ELEN_PARAM' ].value )
452
455
TRMM = (settings ['op' ].value == 'trmm' )
453
456
is_complex = settings ['complex' ].value
454
457
generate_gemm_kernel_inner = generate_gemm_kernel_inner_complex if is_complex else generate_gemm_kernel_inner_real
@@ -667,4 +670,4 @@ def OUTPUT(*args, **kwargs):
667
670
ERROR ("unsupported kernel type {}" .format (settings ['op' ]))
668
671
669
672
if __name__ == "__main__" :
670
- main ()
673
+ main ()
0 commit comments