 import torch.nn as nn
 from torch.utils._pytree import tree_flatten, tree_unflatten

-from torchao.dtypes import TensorCoreTiledLayout, to_affine_quantized_intx_static
+from torchao.dtypes import (
+    Layout,
+    TensorCoreTiledLayout,
+    to_affine_quantized_intx_static,
+)
 from torchao.quantization.quant_primitives import (
     ZeroPointDomain,
 )
@@ -131,6 +135,7 @@ def configure_quantization_mode(
         group_size=-1,
         percdamp=0.01,
         blocksize=128,
+        device: torch.device = torch.device("cuda"),
     ):
         cls.get_qparams_func = get_qparams_func
         cls.quantize_func = quantize_func
@@ -144,6 +149,7 @@ def configure_quantization_mode(
         cls.group_size = group_size
         cls.percdamp = percdamp
         cls.blocksize = blocksize
+        cls.device = device

     @classmethod
     def __torch_function__(
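
For context: the new `device` argument is stored on the class, so a caller picks the accelerator once before calibration and every later `__torch_function__` dispatch reuses it. A minimal sketch of such a configuration call, assuming the enclosing class is the `MultiTensor` subclass defined in this file; the `*_func` arguments are placeholders for the quantizer-specific helpers wired up elsewhere:

    import torch

    # Hypothetical configuration call; get_qparams_func / quantize_func stand
    # in for the helpers the quantizer normally supplies.
    MultiTensor.configure_quantization_mode(
        get_qparams_func=get_qparams_func,
        quantize_func=quantize_func,
        group_size=128,
        percdamp=0.01,
        blocksize=128,
        device=torch.device("xpu"),  # or torch.device("cuda")
    )
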
@@ -178,6 +184,10 @@ def __torch_function__(
         # then we can do the fast thing.

         quantize_linear = not skip_gptq and cls.is_linear_layer(func)
+        if hasattr(cls, "device") and isinstance(cls.device, torch.device):
+            device = cls.device
+        else:
+            device = torch.device("cpu")
         # Determine if function is in-place

         # initialize function tracking
@@ -199,7 +209,7 @@ def __torch_function__(

         # if we're not doing an in place op, move singular tensors to cuda now
         if not is_in_place:
-            flat_args = _tensors_to_cuda(flat_args)
+            flat_args = _tensors_to_device(flat_args, device=device)

         # convert [A, MultiTensor(b), MultiTensor(c1,c2,c3)] => [[A,b,c1], [A,b,c2], [A,b,c3]]
         # if it's in place then instead we first pad, i.e. MultiTensor(b) => MultiTensor(b1, b2, b3)
@@ -208,7 +218,9 @@ def __torch_function__(

         with torch._C.DisableTorchFunctionSubclass():
             if not quantize_linear:  # normal function eval
-                out = cls._evaluate_function(func, grouped_args, spec, is_in_place)
+                out = cls._evaluate_function(
+                    func, grouped_args, spec, is_in_place, device
+                )

                 # go back and unpad everything where possible.
                 if not GPTQ_FUNC_LIST[func]["is_in_place"]:
@@ -217,15 +229,15 @@ def __torch_function__(

             # GPTQ quantization for linear layers
             # Calculate Hessian approximation
-            H = _calculate_hessian(grouped_args, spec)
+            H = _calculate_hessian(grouped_args, spec, device)

             # turn weight MultiTensor into single cuda tensor
             W = args[1]
             if isinstance(W, MultiTensor):
                 W = W.values[0]
             W = W.to(H.device)

-            Q, DQ, all_qparams = cls.faster_quant(H, W.detach())
+            Q, DQ, all_qparams = cls.faster_quant(H, W.detach(), device)

             # make quantized tensor subclass
             qtensor = cls.make_qtensor(Q, all_qparams)
@@ -244,8 +256,8 @@ def __torch_function__(
                 _do_unpad(flat_args, orig_counts=orig_counts)
             return out
         if args[0].debug:
-            act = args[0].values[0].to("cuda")
-            bias = args[2].values[0].to("cuda") if args[2] is not None else args[2]
+            act = args[0].values[0].to(device)
+            bias = args[2].values[0].to(device) if args[2] is not None else args[2]

             new_out = out.values[0].cpu()
             old_out = (
@@ -265,7 +277,7 @@ def __torch_function__(
                 "SQNR for QDQ (this should be inf)", SQNR(DQ, DQ_after)
             )  # matches
             print(
-                "SQNR for weight (can be low)", SQNR(W, DQ.cuda())
+                "SQNR for weight (can be low)", SQNR(W, DQ.to(device))
             )  # fine to not match
             print(
                 "SQNR for output with GPTQ (hopefully 35+)",
@@ -318,14 +330,14 @@ def grouped_to_flat(cls, grouped: List[Tuple[Any, ...]]) -> Tuple[List[Any], bool]:
         return flattened, non_tensors_equal

     @classmethod
-    def _evaluate_function(cls, func, grouped_args, spec, is_in_place):
+    def _evaluate_function(cls, func, grouped_args, spec, is_in_place, device):
         outputs = []
         for inp in grouped_args:
             # we move all remaining cpu tensors to cuda
-            cuda_inp = _tensors_to_cuda(inp)
+            device_inp = _tensors_to_device(inp, device)

             # return input to original structure
-            cur_args, cur_kwargs = tree_unflatten(cuda_inp, spec)
+            cur_args, cur_kwargs = tree_unflatten(device_inp, spec)

             out = func(*cur_args, **cur_kwargs)

@@ -336,7 +348,7 @@ def _evaluate_function(cls, func, grouped_args, spec, is_in_place):
             # categorize func as in place.
             if is_in_place:
                 detected_mutation = _maybe_copy_new_values(
-                    inp, cuda_inp, force=GPTQ_FUNC_LIST[func]["is_in_place"]
+                    inp, device_inp, force=GPTQ_FUNC_LIST[func]["is_in_place"]
                 )  # if we already know it's in place, don't compare, just copy
                 if detected_mutation and GPTQ_FUNC_LIST[func]["is_in_place"] is None:
                     GPTQ_FUNC_LIST[func]["is_in_place"] = True
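
The mutation detection above relies on a helper whose body is outside this diff. A hedged sketch of the mechanism the inline comment describes (compare device-side results against the originals unless `force` already says the op is in-place, and copy mutated values back); illustrative only, not the in-tree body:

    import torch

    # Illustrative stand-in for _maybe_copy_new_values.
    def _maybe_copy_new_values_sketch(orig, moved, force=None):
        detected = bool(force)
        for o, m in zip(orig, moved):
            if isinstance(o, torch.Tensor) and isinstance(m, torch.Tensor):
                if force or not torch.equal(o, m.to(o.device)):
                    o.copy_(m.to(o.device))  # propagate the mutation back
                    detected = True
        return detected
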
@@ -365,13 +377,14 @@ def _evaluate_function(cls, func, grouped_args, spec, is_in_place):
         return final_out

     @classmethod
-    def faster_quant(cls, H, W):
+    def faster_quant(cls, H, W, device):
         """
         GPTQ quantization implementation.

         Args:
             H: Hessian matrix approximation
             W: Weight matrix to quantize
+            device: accelerator device

         Returns:
             Tuple containing:
@@ -457,7 +470,12 @@ def faster_quant(cls, H, W):
                     Hinv[block_start:block_end, block_end:]
                 )

-        torch.cuda.synchronize()
+        if "xpu" in device.type:
+            torch.xpu.synchronize()
+        elif "cuda" in device.type:
+            torch.cuda.synchronize()
+        else:
+            pass

         if all_qparams == []:
             all_qparams.append(cur_qparams)
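
An alternative to growing this if/elif chain per backend is to resolve the backend module from the device type. A minimal sketch (the helper name is mine, not part of the patch), which also makes CPU a natural no-op:

    import torch

    def _sync(device: torch.device) -> None:
        # torch.cuda, torch.xpu, ... expose synchronize(); look the backend
        # module up by device.type and silently skip when it's absent.
        backend = getattr(torch, device.type, None)
        if backend is not None and hasattr(backend, "synchronize"):
            backend.synchronize()
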
@@ -571,6 +589,7 @@ def __init__(self):
         self.make_qtensor = None
         self.skip_layer_func = None
         self.act_fake_quant_func = None
+        self.device = None

     def _check_functions(self):
         assert self.get_qparams_func is not None, "get_qparams_func must be set"
@@ -611,6 +630,7 @@ def _create_quantized_state_dict(
             group_size=group_size,
             percdamp=percdamp,
             blocksize=blocksize,
+            device=self.device,
         )
         # Set the state dict for the original model
         self.state_dict_manager.set_state_dict(model)
@@ -639,6 +659,7 @@ def __init__(
         inner_k_tiles=8,
         padding_allowed=True,
         device: torch.device = torch.device("cuda"),
+        layout: Optional[Layout] = TensorCoreTiledLayout(inner_k_tiles=8),
     ):
         super().__init__()
         self.group_size = group_size
@@ -647,14 +668,31 @@ def __init__(
         self.inner_k_tiles = inner_k_tiles
         self.padding_allowed = padding_allowed
         self.device = device
         self.act_fake_quant_func = None
+        self.layout = layout
         n_bit = 4
+
+        if "xpu" in self.device.type:
+            self.zero_point_domain = ZeroPointDomain.INT
+            self.zeros_precision = torch.int8
+        else:
+            self.zero_point_domain = ZeroPointDomain.FLOAT
+
         self.get_qparams_func = lambda w: get_groupwise_affine_qparams(
-            w, n_bit, group_size
+            w,
+            n_bit,
+            group_size,
+            zero_point_domain=self.zero_point_domain,
         )
         self.quantize_func = (
             lambda w, qparams: groupwise_affine_quantize_tensor_from_qparams(
-                w, qparams[0], qparams[1], n_bit, group_size
+                w,
+                qparams[0],
+                qparams[1],
+                n_bit,
+                group_size,
+                zero_point_domain=self.zero_point_domain,
             )
         )
         self.dequantize_func = (
@@ -664,6 +702,7 @@ def __init__(
                 qparams[1],
                 n_bit,
                 group_size,
+                zero_point_domain=self.zero_point_domain,
             )
         )
         self.combine_qparams_list_func = lambda qparams_list: [
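
The INT-vs-FLOAT split above reflects how the two backends apply zero points: the CUDA tinygemm int4 path keeps a floating-point zero point applied after scaling, while the XPU path stores an integer zero point (narrowed to int8) subtracted before scaling. A simplified illustration of the two conventions, not the torchao kernels themselves:

    import torch

    # For a 4-bit code q in [0, 15], per-group scale s and zero point z:
    def dequant_int_domain(q, s, z):
        # ZeroPointDomain.INT (the XPU branch): integer z, subtracted first.
        return (q - z) * s

    def dequant_float_domain(q, s, z):
        # ZeroPointDomain.FLOAT (tinygemm convention): z is applied after
        # scaling, relative to the mid-point code 2**(4-1) = 8.
        return (q - 8) * s + z
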
@@ -681,15 +720,15 @@ def make_qtensor(q, qparams):
             weight = self.dequantize_func(q, qparams)
             scale = qparams[0]
             zero_point = qparams[1]
+            if self.zero_point_domain == ZeroPointDomain.INT:
+                zero_point = zero_point.to(self.zeros_precision)

             # copied from quant_api apply_int4_weight_only_quant (this should probably be made into a utility fn at some point)
             # mapping_type = MappingType.ASYMMETRIC
             block_size = (1, group_size)
             target_dtype = torch.int32
             quant_min = 0
             quant_max = 15
-            zero_point_domain = ZeroPointDomain.FLOAT
-            _layout = TensorCoreTiledLayout(inner_k_tiles=8)
             # at least the bit up to here should be a util

             quantized_tensor = to_affine_quantized_intx_static(
@@ -700,8 +739,8 @@ def make_qtensor(q, qparams):
                 target_dtype=target_dtype,
                 quant_min=quant_min,
                 quant_max=quant_max,
-                zero_point_domain=zero_point_domain,
-                _layout=_layout,
+                zero_point_domain=self.zero_point_domain,
+                _layout=self.layout,
             )
             return quantized_tensor

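
Threading `self.zero_point_domain` and `self.layout` through here, instead of the previously hard-coded `ZeroPointDomain.FLOAT` and `TensorCoreTiledLayout`, is what lets a caller pick a non-CUDA packing at construction time. A hypothetical construction call; `Int4WeightOnlyGPTQQuantizer` stands in for the class whose `__init__` is modified above, since its name sits outside this diff's context:

    import torch
    from torchao.dtypes import TensorCoreTiledLayout

    quantizer = Int4WeightOnlyGPTQQuantizer(
        group_size=128,
        device=torch.device("xpu"),
        # default shown; swap in an XPU-appropriate Layout if one is available
        layout=TensorCoreTiledLayout(inner_k_tiles=8),
    )
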
@@ -829,12 +868,13 @@ def _flat_to_grouped_and_pad(
     return grouped, orig_counts


-def _tensors_to_cuda(args, move_all=False):
+def _tensors_to_device(args, device=torch.device("cuda"), move_all=False):
     """
-    Move tensors to CUDA for faster processing.
+    Move tensors to the accelerator for faster processing.

     Args:
         args: Arguments that may contain tensors
+        device: accelerator device
         move_all: Whether to move all tensors or just single count tensors

     Returns:
@@ -843,10 +883,10 @@ def _tensors_to_cuda(args, move_all=False):
     new_args = []
     for x in args:
         if isinstance(x, MultiTensor) and (x.count == 1 or move_all):
-            new_args.append(x.__class__(x.values[0].cuda()))
+            new_args.append(x.__class__(x.values[0].to(device)))
         else:
             new_args.append(
-                x.cuda()
+                x.to(device)
                 if isinstance(x, torch.Tensor) and not isinstance(x, MultiTensor)
                 else x
             )
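
To make the move rule concrete (single-value `MultiTensor`s and plain tensors move; multi-value `MultiTensor`s stay put unless `move_all=True`), a small hedged example assuming an XPU build and the `MultiTensor` class defined in this file:

    import torch

    dev = torch.device("xpu")
    args = [MultiTensor(torch.ones(2)), torch.zeros(3), "not a tensor"]
    moved = _tensors_to_device(args, device=dev)
    # moved[0]: MultiTensor with its single value on dev
    # moved[1]: plain tensor on dev
    # moved[2]: returned unchanged
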
@@ -888,13 +928,14 @@ def _do_unpad(args, orig_counts):
             arg.unpad(count)


-def _calculate_hessian(grouped_args, spec):
+def _calculate_hessian(grouped_args, spec, device=torch.device("cuda")):
     """
     Calculate the Hessian matrix for GPTQ.

     Args:
         grouped_args: Grouped arguments
         spec: Original structure specification
+        device: accelerator device

     Returns:
         torch.Tensor: Hessian matrix
@@ -903,10 +944,10 @@ def _calculate_hessian(grouped_args, spec):
     total_batches = 0
     for inp in grouped_args:
         # Move all remaining CPU tensors to CUDA
-        cuda_inp = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inp]
+        device_inp = [x.to(device) if isinstance(x, torch.Tensor) else x for x in inp]

         # Return input to original structure
-        cur_args, _ = tree_unflatten(cuda_inp, spec)
+        cur_args, _ = tree_unflatten(device_inp, spec)

         # Setup x (activation tensor)
         x = cur_args[0].float()