@@ -75,7 +75,6 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias):
         f"need input_tensor shape: {input_tensor.shape} final"
         f"dim to match weight_tensor shape: {weight_tensor.shape} second dim "
     )
-
     # TODO: check groupsize quantization
     # avoid circular dep, TODO: move this to a common util.py
     act_mat = input_tensor
@@ -97,7 +96,6 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias):
     y = torch.ops.aten._weight_int4pack_mm(
         act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
     )
-
     # remove out_feature padding
     orig_out_features = weight_tensor.shape[-2]
     y = y[:, :orig_out_features]
@@ -119,7 +117,7 @@ class TensorCoreTiledLayout(Layout):
     inner_k_tiles: int = 8

     def pre_process(self, input: torch.Tensor) -> torch.Tensor:
-        orig_out_features, orig_in_features = input.shape
+        orig_out_features, orig_in_features = input.shape[-2:]
         in_features = find_multiple(orig_in_features, 1024)
         out_features = find_multiple(orig_out_features, 8)
         input = torch.nn.functional.pad(
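
A note on the `shape[-2:]` change: unpacking `input.shape` directly raises a `ValueError` once a leading expert dimension is present, while slicing the last two dims works for both layouts. A minimal sketch (shapes made up for illustration):

```python
import torch

w2d = torch.randn(11, 1000)     # (out_features, in_features)
w3d = torch.randn(4, 11, 1000)  # (num_experts, out_features, in_features), MoE case

for w in (w2d, w3d):
    # `out_f, in_f = w.shape` would raise ValueError on the 3D tensor;
    # slicing the trailing two dims handles both.
    out_f, in_f = w.shape[-2:]
    assert (out_f, in_f) == (11, 1000)
```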
@@ -160,18 +158,18 @@ def post_process(
         zero_point: torch.Tensor,
         block_size: Tuple[int, ...],
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        orig_out_features, orig_in_features = input.shape
+        orig_out_features, orig_in_features = input.shape[-2:]
         in_features = find_multiple(orig_in_features, 1024)
         out_features = find_multiple(orig_out_features, 8)
         input = torch.nn.functional.pad(
             input,
             (0, in_features - orig_in_features, 0, out_features - orig_out_features),
         )
         assert (
-            len(block_size) == 2
-        ), f"TensorCoreTiledLayout only supports len(block_size) == 2, got: {block_size}"
-        scale_pad_dim_0 = (out_features - orig_out_features) // block_size[0]
-        scale_pad_dim_1 = (in_features - orig_in_features) // block_size[1]
+            len(block_size) == 2 or len(block_size) == 3
+        ), f"TensorCoreTiledLayout only supports len(block_size) == 2 or 3, got: {block_size}"
+        scale_pad_dim_0 = (out_features - orig_out_features) // block_size[-2]
+        scale_pad_dim_1 = (in_features - orig_in_features) // block_size[-1]
         scale = torch.nn.functional.pad(scale, (0, scale_pad_dim_1, 0, scale_pad_dim_0))
         zero_point = torch.nn.functional.pad(
             zero_point, (0, scale_pad_dim_1, 0, scale_pad_dim_0)
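
A worked example of the padding arithmetic above, with assumed sizes and a local re-implementation of `find_multiple` so the sketch is self-contained; negative indexing into `block_size` lets the same two lines serve both 2D and 3D (expert-batched) block sizes:

```python
def find_multiple(n: int, k: int) -> int:
    # Smallest multiple of k that is >= n (mirrors torchao's helper).
    return n if n % k == 0 else n + k - (n % k)

orig_out_features, orig_in_features = 11, 1000       # assumed weight shape
in_features = find_multiple(orig_in_features, 1024)  # -> 1024
out_features = find_multiple(orig_out_features, 8)   # -> 16

block_size = (1, 1, 128)  # 3D block_size for an expert-stacked weight
# Negative indices skip the leading expert dim, so this also works
# unchanged for block_size == (1, 128).
scale_pad_dim_0 = (out_features - orig_out_features) // block_size[-2]  # 5
scale_pad_dim_1 = (in_features - orig_in_features) // block_size[-1]    # 0
```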
@@ -272,11 +270,22 @@ def from_plain(
         assert (
             int_data.dtype == torch.int32
         ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype"
-        packed_weight = torch.ops.aten._convert_weight_to_int4pack(
-            int_data, _layout.inner_k_tiles
-        )
-        scale = scale.reshape(int_data.shape[0], -1)
-        zero_point = zero_point.reshape(int_data.shape[0], -1)
+        def quant_2d(mat):
+            return torch.ops.aten._convert_weight_to_int4pack(
+                mat, _layout.inner_k_tiles
+            )
+        if int_data.dim() == 3:  # for moe quant
+            num_experts = int_data.shape[0]
+            packed_weight_list = []
+            for expert in range(num_experts):
+                packed_weight_list.append(quant_2d(int_data[expert]).unsqueeze(0))
+            packed_weight = torch.cat(packed_weight_list, dim=0)
+            scale = scale.reshape(int_data.shape[0], int_data.shape[-2], -1)
+            zero_point = zero_point.reshape(int_data.shape[0], int_data.shape[-2], -1)
+        else:
+            packed_weight = quant_2d(int_data)
+            scale = scale.reshape(int_data.shape[0], -1)
+            zero_point = zero_point.reshape(int_data.shape[0], -1)
         from torchao.quantization.utils import pack_tinygemm_scales_and_zeros

         scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype)
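
The loop in `from_plain` follows a generic "apply a 2D-only op per expert, then stack" pattern. A self-contained sketch of that pattern, with a placeholder standing in for `aten._convert_weight_to_int4pack` (the real op needs specific dtypes and hardware):

```python
import torch

def pack_2d(mat: torch.Tensor) -> torch.Tensor:
    # Placeholder for torch.ops.aten._convert_weight_to_int4pack; any
    # shape-preserving 2D transform illustrates the stacking pattern.
    return mat + 1

int_data = torch.randint(0, 16, (4, 16, 64), dtype=torch.int32)  # (E, n, k)
if int_data.dim() == 3:  # expert-stacked: pack each expert, then cat on dim 0
    packed = torch.cat(
        [pack_2d(int_data[e]).unsqueeze(0) for e in range(int_data.shape[0])],
        dim=0,
    )
else:
    packed = pack_2d(int_data)
assert packed.shape == int_data.shape
```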
@@ -336,6 +345,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
             )

+        if func in [aten.select.int, aten.index.Tensor]:
+            assert not (func is aten.select.int and args[1] != 0), "aten.select.int currently only has support for dim=0"
+            return return_and_correct_aliasing(
+                func,
+                args,
+                kwargs,
+                args[0]._apply_fn_to_data(
+                    lambda x: func(x, *args[1:], **kwargs)
+                ),
+            )
+
+
         if func is aten.t.default:
             """we don't need to repack the weight and just rely on external
             shape being changed and record the status of transpose/no-transpose
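
To see which aten ops the new branch intercepts, this runnable sketch logs dispatches for the two indexing patterns MoE routing uses; on a plain tensor, `x[0]` lowers to `aten.select.int` with dim 0 and fancy indexing lowers to `aten.index.Tensor`:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LogOps(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)  # e.g. aten.select.int, aten.index.Tensor
        return func(*args, **(kwargs or {}))

x = torch.randn(4, 8, 8)  # stand-in for an expert-stacked weight
with LogOps():
    _ = x[0]                     # -> aten.select.int (dim=0): one expert
    _ = x[torch.tensor([0, 2])]  # -> aten.index.Tensor: a subset of experts
```

In the subclass both ops are simply forwarded to the inner `packed_weight` and `scale_and_zero` tensors via `_apply_fn_to_data`, which is why only dim 0 (the expert dim) needs support.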
@@ -386,11 +407,15 @@ def block_size(self):

         scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
         cur_shape = self.shape
-        assert len(cur_shape) == 4
+        if len(cur_shape) == 5:
+            ones = [1, 1]
+            cur_shape = cur_shape[1:]
+        elif len(cur_shape) == 4:
+            ones = [1]
         inner_k_tiles = cur_shape[-1] * 2
         original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16))
         groupsize = int(original_shape[1] / scale.shape[-2])
-        return (1, groupsize)
+        return tuple([*ones, groupsize])

     def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         from torchao.quantization.quant_primitives import (
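
A worked example, with assumed sizes, of the shape arithmetic behind `block_size`: `_convert_weight_to_int4pack` packs an `(n, k)` weight into `(n/8, k/(inner_k_tiles*16), 32, inner_k_tiles/2)` (this layout is implied by the recovery math in the property itself), and the 5D MoE case just prepends `num_experts`, so the property strips the leading dim and prepends one extra `1` per kept batch dim:

```python
inner_k_tiles = 8
n, k, n_groups = 128, 1024, 8  # assumed original weight and group count

# Packed 4D layout produced by aten._convert_weight_to_int4pack:
cur_shape = (n // 8, k // (inner_k_tiles * 16), 32, inner_k_tiles // 2)

recovered_k = cur_shape[1] * ((cur_shape[-1] * 2) * 16)
assert recovered_k == k
groupsize = recovered_k // n_groups  # n_groups stands in for scale.shape[-2]
assert groupsize == 128

for ndim, ones in ((4, [1]), (5, [1, 1])):  # 4D weight vs 5D MoE weight
    block_size = tuple([*ones, groupsize])
    assert len(block_size) == ndim - 2  # one block dim per original weight dim
```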
@@ -399,35 +424,53 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         )
         from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros

+        def dequant_4d(self):
+            cur_shape = self.shape
+            scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
+            assert len(cur_shape) == 4
+            inner_k_tiles = cur_shape[-1] * 2
+            original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16))
+            eye_shape = original_shape[1]
+            groupsize = int(original_shape[1] / scale.shape[-2])
+            block_size = (1, groupsize)
+            original_dtype = torch.bfloat16
+            assert len(block_size) == 2 and block_size[0] == 1
+            dequantized = torch.ops.aten._weight_int4pack_mm(
+                torch.eye(eye_shape, device=self.device, dtype=original_dtype),
+                self.packed_weight,
+                groupsize,
+                self.scale_and_zero,
+            )
+            dequantized = dequantized.t().contiguous()
+            return dequantized
+
+        cur_shape = self.shape
+
+        if len(cur_shape) == 4:
+            dequantized = dequant_4d(self)
+        else:
+            assert len(cur_shape) == 5
+            num_experts = cur_shape[0]
+            dequantized_list = []
+            for expert in range(num_experts):
+                dequantized_list.append(dequant_4d(self[expert]).unsqueeze(0))
+            dequantized = torch.cat(dequantized_list, dim=0)
+
+
         scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
+        # TODO: move this to `unpack_tinygemm_scales_and_zeros`?
+        scale = scale.reshape(scale.shape[:-1]).contiguous()
+        zero = zero.reshape(zero.shape[:-1]).contiguous()

-        cur_shape = self.shape
-        assert len(cur_shape) == 4
-        inner_k_tiles = cur_shape[-1] * 2
-        original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16))
-        eye_shape = original_shape[1]
-        groupsize = int(original_shape[1] / scale.shape[-2])
-        block_size = (1, groupsize)
         device = self.device
-        original_dtype = torch.bfloat16
+
         target_dtype = torch.int32
         quant_min = 0
         quant_max = 15
         zero_point_domain = ZeroPointDomain.FLOAT
-        assert len(block_size) == 2 and block_size[0] == 1
-        dequantized = torch.ops.aten._weight_int4pack_mm(
-            torch.eye(eye_shape, device=device, dtype=original_dtype),
-            self.packed_weight,
-            groupsize,
-            self.scale_and_zero,
-        )
-        dequantized = dequantized.t().contiguous()
-        # TODO: move this to `unpack_tinygemm_scales_and_zeros`?
-        scale = scale.reshape(scale.shape[:-1]).contiguous()
-        zero = zero.reshape(zero.shape[:-1]).contiguous()
         int_data = quantize_affine(
             dequantized,
-            block_size,
+            self.block_size,
             scale,
             zero,
             target_dtype,
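
The `dequant_4d` helper above leans on an identity-matrix trick: `_weight_int4pack_mm` computes `act @ W_deq.T`, so feeding `eye(k)` as the activation returns the dequantized, transposed weight directly. A plain-matmul sketch of why that works (the int4 kernel is swapped for a dense matmul so the sketch runs anywhere; sizes are assumed):

```python
import torch

n, k = 128, 1024  # assumed out/in features
w_deq = torch.randn(n, k, dtype=torch.bfloat16)  # stands in for the dequantized weight

# _weight_int4pack_mm(act, packed, groupsize, scale_and_zero) ~= act @ w_deq.t();
# with act = eye(k), the product is exactly w_deq.t().
out = torch.eye(k, dtype=torch.bfloat16) @ w_deq.t()
assert torch.equal(out.t().contiguous(), w_deq)
```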