@@ -394,6 +394,40 @@ def quantize(
     return model


+def linear_forward_8da4w(
+    x, weight_int8, scales, zeros, out_features, group_size, precision
+):
+    x = per_token_dynamic_quant(x)
+    # TODO: verify and remove following reshape code
+    # origin_x_size = x.size()
+    # x = x.reshape(-1, origin_x_size[-1])
+
+    # TODO: better API
+    # weight_int8 = torch.ops.quantized_decomposed.unpack_int4_to_int8(weight_int4packed)
+    n_bit = 4
+    quant_min = -(2 ** (n_bit - 1))
+    quant_max = 2 ** (n_bit - 1) - 1
+    w_dq = torch.ops.quantized_decomposed.dequantize_per_channel_group(
+        weight_int8,
+        scales,
+        zeros,
+        quant_min,
+        quant_max,
+        torch.int8,
+        group_size,
+        precision,
+    )
+
+    # x = x.to(torch.float16)
+    # w_dq = w_dq.to(torch.float16)
+    c = torch.nn.functional.linear(x, w_dq)
+
+    # new_shape = origin_x_size[:-1] + (out_features,)
+    # c = c.reshape(new_shape)
+
+    return c
+
+
 class Int8DynActInt4WeightLinear(torch.nn.Module):
     __constants__ = ["in_features", "out_features"]

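The added path does two things at call time: per_token_dynamic_quant fake-quantizes the activations to int8 per token, and dequantize_per_channel_group expands the 4-bit weights (stored in int8) back to `precision` using one scale and zero point per group of `group_size` input channels, after which a regular F.linear runs. A minimal pure-PyTorch sketch of both steps, assuming weight_int8 is [out_features, in_features] and scales/zeros are [out_features, in_features // group_size]; the helper names below are illustrative approximations, not the decomposed ops themselves:

import torch

def per_token_dynamic_quant_reference(x, eps=1e-5):
    # Approximate per-token asymmetric int8 quantize/dequantize round trip.
    qmin, qmax = -128, 127
    x_min = x.amin(dim=-1, keepdim=True).clamp(max=0)
    x_max = x.amax(dim=-1, keepdim=True).clamp(min=0)
    scale = ((x_max - x_min) / (qmax - qmin)).clamp(min=eps)
    zero_point = qmin - torch.round(x_min / scale)
    x_q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return (x_q - zero_point) * scale

def dequantize_per_channel_group_reference(weight_int8, scales, zeros, group_size, precision):
    # Map each int4 value (stored as int8) back to `precision` with its group's
    # scale and zero point: w = (q - zero) * scale, groups taken along the
    # input-feature dimension.
    out_features, in_features = weight_int8.shape
    w = weight_int8.to(precision).reshape(out_features, in_features // group_size, group_size)
    w = (w - zeros.to(precision).unsqueeze(-1)) * scales.to(precision).unsqueeze(-1)
    return w.reshape(out_features, in_features)
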
@@ -433,6 +467,7 @@ def __init__(
         self.in_features = in_features
         self.out_features = out_features
         assert not bias, "require bias=False"
+        # TODO: align groupsize naming
         self.group_size = group_size
         # Precision of the activation which also indicates
         # output precision of the dynamically quantized linear layer
@@ -469,10 +504,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             self.scales,
             self.zeros,
             self.out_features,
-            self.groupsize,
+            self.group_size,
             self.precision,
         )

+
 from math import gcd
 from functools import reduce

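With the rename, forward now passes the module's group_size attribute that __init__ actually sets. A hypothetical usage sketch (constructor arguments and shapes here are assumptions for illustration, not taken from this diff):

import torch

# Hypothetical example; keyword arguments and defaults are assumed.
layer = Int8DynActInt4WeightLinear(
    in_features=4096,
    out_features=4096,
    bias=False,          # the layer asserts bias=False
    group_size=128,
    precision=torch.float32,
)
x = torch.randn(2, 16, 4096)
y = layer(x)             # activations fake-quantized per token, weights dequantized per group
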
@@ -630,7 +666,7 @@ def _convert_for_runtime(self, model):
             model,
             self.groupsize,
             self.padding_allowed,
-            torch.int8,
+            self.precision,
             self.precision,
         )
         return model