@@ -100,7 +100,7 @@ def load( # pylint: disable=too-many-branches,too-many-statements
100
100
kernel_map = np .full ((tc .dev .MAX_PROC , tc .dev .MASK_WIDTH_LARGE ),
101
101
fill_value = _INVALID_VALUE , dtype = np .int64 )
102
102
kernels_used = np .zeros ((tc .dev .MAX_PROC , tc .dev .MASK_WIDTH_LARGE ), dtype = np .int64 )
103
- kernel_data = np .zeros ((tc .dev .MAX_PROC , tc .dev .MASK_WIDTH_LARGE , 9 ), dtype = np .int8 )
103
+ kernel_data = np .zeros ((tc .dev .MAX_PROC , tc .dev .MASK_WIDTH_LARGE , 9 ), dtype = np .uint8 )
104
104
# There are four 32-bit words per 9-byte kernel.
105
105
# The value map is initialized with zeros so we can later ignore unused entries and use
106
106
# memcpy() on initialized and uninitialized data.
@@ -438,6 +438,8 @@ def add_kernel_data(ll, p, col_target, b):
438
438
kernel_map [p ][col ] = ll
439
439
440
440
assert kernels_used [p ][col ] <= 8
441
+ assert isinstance (b , np .int64 ), f'Kernel is type { type (b )} instead of numpy.int64'
442
+ assert 0 <= b <= 255 , f'Trying to add kernel value { b } '
441
443
kernel_data [p ][col ][8 - kernels_used [p ][col ]] = b & 0xff
442
444
kernels_used [p ][col ] += 1
443
445
@@ -457,7 +459,7 @@ def add_kernel_data(ll, p, col_target, b):
457
459
col_target , col_bytes = divmod (start_col * ksize * in_exp , 9 )
458
460
# Pad out the leftovers
459
461
for _ in range (col_bytes // qfactor ): # FIXME for quantization
460
- col_target = add_kernel_data (ll , p , col_target , 0 )
462
+ col_target = add_kernel_data (ll , p , col_target , np . int64 ( 0 ) )
461
463
462
464
out_range = out_expand [ll ] if conv_groups [ll ] == 1 else 1
463
465
for expand in range (out_range ):
@@ -506,8 +508,10 @@ def add_kernel_data(ll, p, col_target, b):
506
508
& (2 ** abs (quantization [ll ])- 1 )
507
509
if not flatten [ll ]:
508
510
k |= this_kern << (i * abs (quantization [ll ]))
509
- else :
511
+ elif len ( k ) > 0 :
510
512
k = np .append (k , this_kern )
513
+ else :
514
+ k = this_kern
511
515
n += 1
512
516
mask >>= 1
513
517
if debug :
@@ -525,8 +529,8 @@ def add_kernel_data(ll, p, col_target, b):
525
529
),
526
530
)
527
531
for i in range (0 , len (k ) // qfactor ):
528
- e = 0
529
- for j in range (qfactor ):
532
+ e = k [ i * qfactor ]
533
+ for j in range (1 , qfactor ):
530
534
e |= k [i * qfactor + j ] << (j * abs (quantization [ll ]))
531
535
col_target = add_kernel_data (ll , p , col_target , e )
532
536
else :
@@ -536,7 +540,7 @@ def add_kernel_data(ll, p, col_target, b):
536
540
537
541
else : # When expanding, need to pad with zero kernels if needed
538
542
for _ in range (ksize // qfactor ):
539
- col_target = add_kernel_data (ll , p , col_target , 0 )
543
+ col_target = add_kernel_data (ll , p , col_target , np . int64 ( 0 ) )
540
544
541
545
# Consume kernels
542
546
if not flatten [ll ]:
@@ -552,7 +556,7 @@ def add_kernel_data(ll, p, col_target, b):
552
556
and kernels_used [p ][kern_offs [ll ] + col_target ] > 0 : # Partials
553
557
col_target += 1
554
558
while col_target - start_col < kern_len [ll ]:
555
- col_target = add_kernel_data (ll , p , col_target , 0 )
559
+ col_target = add_kernel_data (ll , p , col_target , np . int64 ( 0 ) )
556
560
if flatten [ll ]:
557
561
kern_len [ll ] = col_target
558
562
elif not state .new_kernel_loader :
0 commit comments