
Commit 21a2d29

xiaowangintel authored and pytorchmergebot committed
Enable Int4WeightOnlyGPTQQuantizer on Intel GPU. (#2200)
Following the requests in pytorch/pytorch#153019, we enable int4wo-GPTQ for Intel GPU in pytorch/ao now that RTN support is ready.

How to run int4wo-GPTQ on an Intel GPU:

```python
from pathlib import Path

import torch

from torchao._models._eval import (
    LMEvalInputRecorder,
    TransformerEvalWrapper,
)
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
from torchao._models.llama.model import Transformer, prepare_inputs_for_model
from torchao._models.llama.tokenizer import get_tokenizer
from torchao.dtypes import (
    Int4XPULayout,
)

precision = torch.bfloat16
device = "xpu"

checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
model = Transformer.from_name(checkpoint_path.parent.name)
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
model.load_state_dict(checkpoint, assign=True)
model = model.to(dtype=precision, device="cpu")
model.eval()

tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
tokenizer = get_tokenizer(
    tokenizer_path,
    "Llama-2-7b-chat-hf",
)

groupsize = 64
blocksize = 128
percdamp = 0.01
calibration_tasks = ["wikitext"]
calibration_limit = 1
calibration_seq_length = 100
input_prep_func = prepare_inputs_for_model
pad_calibration_inputs = False

inputs = (
    LMEvalInputRecorder(
        tokenizer,
        calibration_seq_length,
        input_prep_func,
        model.config.vocab_size,
        pad_calibration_inputs,
        device="cpu",
    )
    .record_inputs(
        calibration_tasks,
        calibration_limit,
    )
    .get_recorded_inputs()
)

quantizer = Int4WeightOnlyGPTQQuantizer(
    groupsize,
    blocksize,
    percdamp,
    device=torch.device(device),
    layout=Int4XPULayout(),
)

model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
model = quantizer.quantize(model, *inputs).xpu()
model.reset_caches()

limit = 1
result = TransformerEvalWrapper(
    model.xpu(),
    tokenizer,
    model.config.block_size,
    prepare_inputs_for_model,
    device,
).run_eval(
    ["wikitext"],
    limit,
)
```

Pull Request resolved: #2200
Approved by: https://github.com/liangan1, https://github.com/jerryzh168
1 parent e4f2715 commit 21a2d29

File tree

2 files changed: 69 additions & 28 deletions

torchao/_models/_eval.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -35,7 +35,7 @@ def __init__(
         self, model, tokenizer, max_seq_length, input_prep_func=None, device="cuda"
     ):
         try:
-            super().__init__()
+            super().__init__(device=device)
         except TypeError:
             # lm_eval 0.4.2 removed the default init
             super().__init__("gpt2", device="cpu")
```
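With the device now forwarded to the lm_eval base class, the wrapper from the commit-message example can drive evaluation on the Intel GPU directly. A minimal, hedged usage sketch (argument order follows the example above; `model` and `tokenizer` are assumed to be set up as shown there):

```python
from torchao._models._eval import TransformerEvalWrapper
from torchao._models.llama.model import prepare_inputs_for_model

# Assumes `model` is the GPTQ-quantized Transformer and `tokenizer` was loaded
# as in the commit message example above.
result = TransformerEvalWrapper(
    model.xpu(),                # model on the Intel GPU
    tokenizer,
    model.config.block_size,    # max_seq_length
    prepare_inputs_for_model,   # input_prep_func
    "xpu",                      # device, now passed through to the lm_eval base class
).run_eval(
    ["wikitext"],
    1,                          # limit
)
```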

torchao/quantization/GPTQ/GPTQ.py

Lines changed: 68 additions & 27 deletions
```diff
@@ -10,7 +10,11 @@
 import torch.nn as nn
 from torch.utils._pytree import tree_flatten, tree_unflatten
 
-from torchao.dtypes import TensorCoreTiledLayout, to_affine_quantized_intx_static
+from torchao.dtypes import (
+    Layout,
+    TensorCoreTiledLayout,
+    to_affine_quantized_intx_static,
+)
 from torchao.quantization.quant_primitives import (
     ZeroPointDomain,
 )
@@ -131,6 +135,7 @@ def configure_quantization_mode(
         group_size=-1,
         percdamp=0.01,
         blocksize=128,
+        device: torch.device = torch.device("cuda"),
     ):
         cls.get_qparams_func = get_qparams_func
         cls.quantize_func = quantize_func
@@ -144,6 +149,7 @@ def configure_quantization_mode(
         cls.group_size = group_size
         cls.percdamp = percdamp
         cls.blocksize = blocksize
+        cls.device = device
 
     @classmethod
     def __torch_function__(
@@ -178,6 +184,10 @@ def __torch_function__(
         # then we can do the fast thing.
 
         quantize_linear = not skip_gptq and cls.is_linear_layer(func)
+        if hasattr(cls, "device") and isinstance(cls.device, torch.device):
+            device = cls.device
+        else:
+            device = "cpu"
         # Determine if function is in-place
 
         # initialize function tracking
@@ -199,7 +209,7 @@ def __torch_function__(
 
         # if we're not doing an in place op, move singular tensors to cuda now
         if not is_in_place:
-            flat_args = _tensors_to_cuda(flat_args)
+            flat_args = _tensors_to_device(flat_args, device=device)
 
         # convert [A, MultiTensor(b), MultiTensor(c1,c2,c3)] => [[A,b,c1], [A,b,c2] [A,b,c3]]
         # if its in place then instead we first pad i.e. MultiTensor(b) => MultiTensor(b1, b2, b3)
@@ -208,7 +218,9 @@
 
         with torch._C.DisableTorchFunctionSubclass():
             if not quantize_linear:  # normal function eval
-                out = cls._evaluate_function(func, grouped_args, spec, is_in_place)
+                out = cls._evaluate_function(
+                    func, grouped_args, spec, is_in_place, device
+                )
 
                 # go back and unpad everything where possible.
                 if not GPTQ_FUNC_LIST[func]["is_in_place"]:
@@ -217,15 +229,15 @@
 
             # GPTQ quantization for linear layers
             # Calculate Hessian approximation
-            H = _calculate_hessian(grouped_args, spec)
+            H = _calculate_hessian(grouped_args, spec, device)
 
             # turn weight MultiTensor into single cuda tensor
             W = args[1]
             if isinstance(W, MultiTensor):
                 W = W.values[0]
             W = W.to(H.device)
 
-            Q, DQ, all_qparams = cls.faster_quant(H, W.detach())
+            Q, DQ, all_qparams = cls.faster_quant(H, W.detach(), device)
 
             # make quantized tensor subclass
             qtensor = cls.make_qtensor(Q, all_qparams)
@@ -244,8 +256,8 @@
                 _do_unpad(flat_args, orig_counts=orig_counts)
             return out
         if args[0].debug:
-            act = args[0].values[0].to("cuda")
-            bias = args[2].values[0].to("cuda") if args[2] is not None else args[2]
+            act = args[0].values[0].to(device)
+            bias = args[2].values[0].to(device) if args[2] is not None else args[2]
 
             new_out = out.values[0].cpu()
             old_out = (
@@ -265,7 +277,7 @@
                 "SQNR for QDQ (this should be inf)", SQNR(DQ, DQ_after)
             )  # matches
             print(
-                "SQNR for weight (can be low)", SQNR(W, DQ.cuda())
+                "SQNR for weight (can be low)", SQNR(W, DQ.to(device))
             )  # fine to not match
             print(
                 "SQNR for output with GPTQ (hopefully 35+)",
@@ -318,14 +330,14 @@ def grouped_to_flat(cls, grouped: List[Tuple[Any, ...]]) -> Tuple[List[Any], boo
         return flattened, non_tensors_equal
 
     @classmethod
-    def _evaluate_function(cls, func, grouped_args, spec, is_in_place):
+    def _evaluate_function(cls, func, grouped_args, spec, is_in_place, device):
         outputs = []
         for inp in grouped_args:
             # we move all remaining cpu tensors to cuda
-            cuda_inp = _tensors_to_cuda(inp)
+            device_inp = _tensors_to_device(inp, device)
 
             # return input to original structure
-            cur_args, cur_kwargs = tree_unflatten(cuda_inp, spec)
+            cur_args, cur_kwargs = tree_unflatten(device_inp, spec)
 
             out = func(*cur_args, **cur_kwargs)
 
@@ -336,7 +348,7 @@ def _evaluate_function(cls, func, grouped_args, spec, is_in_place):
             # categortize func as in place.
             if is_in_place:
                 detected_mutation = _maybe_copy_new_values(
-                    inp, cuda_inp, force=GPTQ_FUNC_LIST[func]["is_in_place"]
+                    inp, device_inp, force=GPTQ_FUNC_LIST[func]["is_in_place"]
                 )  # if we already know its in place, don't compare, just copy
                 if detected_mutation and GPTQ_FUNC_LIST[func]["is_in_place"] is None:
                     GPTQ_FUNC_LIST[func]["is_in_place"] = True
@@ -365,13 +377,14 @@ def _evaluate_function(cls, func, grouped_args, spec, is_in_place):
         return final_out
 
     @classmethod
-    def faster_quant(cls, H, W):
+    def faster_quant(cls, H, W, device):
         """
         GPTQ quantization implementation.
 
         Args:
             H: Hessian matrix approximation
             W: Weight matrix to quantize
+            device: accelerator device
 
         Returns:
             Tuple containing:
@@ -457,7 +470,12 @@ def faster_quant(cls, H, W):
                 Hinv[block_start:block_end, block_end:]
             )
 
-        torch.cuda.synchronize()
+        if "xpu" in device.type:
+            torch.xpu.synchronize()
+        elif "cuda" in device.type:
+            torch.cuda.synchronize()
+        else:
+            pass
 
         if all_qparams == []:
             all_qparams.append(cur_qparams)
```
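The explicit xpu/cuda branch above could also be written as a backend-agnostic helper. A hedged sketch, not part of this PR, assuming only that `torch` exposes a module named after `device.type` with a `synchronize()` function (true for `cuda` and `xpu`):

```python
import torch

def _synchronize(device: torch.device) -> None:
    # Dispatch to torch.cuda.synchronize() / torch.xpu.synchronize() by device type;
    # backends without a synchronize() (e.g. plain CPU) fall through as a no-op.
    backend = getattr(torch, device.type, None)
    if backend is not None and hasattr(backend, "synchronize"):
        backend.synchronize()
```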
```diff
@@ -571,6 +589,7 @@ def __init__(self):
         self.make_qtensor = None
         self.skip_layer_func = None
         self.act_fake_quant_func = None
+        self.device = None
 
     def _check_functions(self):
         assert self.get_qparams_func is not None, "get_qparams_func must be set"
@@ -611,6 +630,7 @@ def _create_quantized_state_dict(
             group_size=group_size,
             percdamp=percdamp,
             blocksize=blocksize,
+            device=self.device,
         )
         # Set the state dict for the original model
         self.state_dict_manager.set_state_dict(model)
@@ -639,6 +659,7 @@ def __init__(
         inner_k_tiles=8,
         padding_allowed=True,
         device: torch.device = torch.device("cuda"),
+        layout: Optional[Layout] = TensorCoreTiledLayout(inner_k_tiles=8),
     ):
         super().__init__()
         self.group_size = group_size
@@ -647,14 +668,31 @@ def __init__(
         self.inner_k_tiles = inner_k_tiles
         self.padding_allowed = padding_allowed
         self.device = device
+        self.device = self.device
         self.act_fake_quant_func = None
+        self.layout = layout
         n_bit = 4
+
+        if "xpu" in self.device.type:
+            self.zero_point_domain = ZeroPointDomain.INT
+            self.zeros_precision = torch.int8
+        else:
+            self.zero_point_domain = ZeroPointDomain.FLOAT
+
         self.get_qparams_func = lambda w: get_groupwise_affine_qparams(
-            w, n_bit, group_size
+            w,
+            n_bit,
+            group_size,
+            zero_point_domain=self.zero_point_domain,
         )
         self.quantize_func = (
             lambda w, qparams: groupwise_affine_quantize_tensor_from_qparams(
-                w, qparams[0], qparams[1], n_bit, group_size
+                w,
+                qparams[0],
+                qparams[1],
+                n_bit,
+                group_size,
+                zero_point_domain=self.zero_point_domain,
             )
         )
         self.dequantize_func = (
@@ -664,6 +702,7 @@ def __init__(
                 qparams[1],
                 n_bit,
                 group_size,
+                zero_point_domain=self.zero_point_domain,
             )
         )
         self.combine_qparams_list_func = lambda qparams_list: [
@@ -681,15 +720,15 @@ def make_qtensor(q, qparams):
             weight = self.dequantize_func(q, qparams)
             scale = qparams[0]
             zero_point = qparams[1]
+            if self.zero_point_domain == ZeroPointDomain.INT:
+                zero_point = zero_point.to(self.zeros_precision)
 
             # copied from quant_api apply_int4_weight_only_quant (this should probably be made into a utility fn at some point)
             # mapping_type = MappingType.ASYMMETRIC
             block_size = (1, group_size)
             target_dtype = torch.int32
             quant_min = 0
             quant_max = 15
-            zero_point_domain = ZeroPointDomain.FLOAT
-            _layout = TensorCoreTiledLayout(inner_k_tiles=8)
             # at least the big up to here should be a util
 
             quantized_tensor = to_affine_quantized_intx_static(
@@ -700,8 +739,8 @@ def make_qtensor(q, qparams):
                 target_dtype=target_dtype,
                 quant_min=quant_min,
                 quant_max=quant_max,
-                zero_point_domain=zero_point_domain,
-                _layout=_layout,
+                zero_point_domain=self.zero_point_domain,
+                _layout=self.layout,
             )
             return quantized_tensor
 
```
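Taken together, the new `layout` argument and the XPU branch in `__init__` mean the backend is selected entirely through the constructor. A hedged sketch mirroring the commit-message example (positional arguments are groupsize, blocksize, percdamp):

```python
import torch
from torchao.dtypes import Int4XPULayout
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer

# Intel GPU: an integer zero-point domain (int8 zeros) is selected internally for "xpu",
# and the XPU packing layout is passed in explicitly.
xpu_quantizer = Int4WeightOnlyGPTQQuantizer(
    64, 128, 0.01, device=torch.device("xpu"), layout=Int4XPULayout()
)

# CUDA (default): float zero-point domain with TensorCoreTiledLayout(inner_k_tiles=8).
cuda_quantizer = Int4WeightOnlyGPTQQuantizer(64, 128, 0.01, device=torch.device("cuda"))
```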

```diff
@@ -829,12 +868,13 @@ def _flat_to_grouped_and_pad(
     return grouped, orig_counts
 
 
-def _tensors_to_cuda(args, move_all=False):
+def _tensors_to_device(args, device=torch.device("cuda"), move_all=False):
     """
-    Move tensors to CUDA for faster processing.
+    Move tensors to accelerator for faster processing.
 
     Args:
         args: Arguments that may contain tensors
+        device: accelerator device
         move_all: Whether to move all tensors or just single count tensors
 
     Returns:
@@ -843,10 +883,10 @@ def _tensors_to_cuda(args, move_all=False):
     new_args = []
     for x in args:
         if isinstance(x, MultiTensor) and (x.count == 1 or move_all):
-            new_args.append(x.__class__(x.values[0].cuda()))
+            new_args.append(x.__class__(x.values[0].to(device)))
         else:
             new_args.append(
-                x.cuda()
+                x.to(device)
                 if isinstance(x, torch.Tensor) and not isinstance(x, MultiTensor)
                 else x
             )
@@ -888,13 +928,14 @@ def _do_unpad(args, orig_counts):
             arg.unpad(count)
 
 
-def _calculate_hessian(grouped_args, spec):
+def _calculate_hessian(grouped_args, spec, device=torch.device("cuda")):
     """
     Calculate the Hessian matrix for GPTQ.
 
     Args:
         grouped_args: Grouped arguments
        spec: Original structure specification
+       device: accelerator device
 
     Returns:
        torch.Tensor: Hessian matrix
@@ -903,10 +944,10 @@ def _calculate_hessian(grouped_args, spec):
    total_batches = 0
    for inp in grouped_args:
        # Move all remaining CPU tensors to CUDA
-       cuda_inp = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inp]
+       device_inp = [x.to(device) if isinstance(x, torch.Tensor) else x for x in inp]
 
        # Return input to original structure
-       cur_args, _ = tree_unflatten(cuda_inp, spec)
+       cur_args, _ = tree_unflatten(device_inp, spec)
 
        # Setup x (activation tensor)
        x = cur_args[0].float()
```
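For context, `_calculate_hessian` accumulates the standard GPTQ input-covariance approximation H ≈ 2·E[x xᵀ] over the calibration batches. A hedged sketch of that running update (names and exact scaling are illustrative, not lifted from torchao):

```python
import torch

def hessian_update(H: torch.Tensor, x: torch.Tensor, total_tokens: int):
    # x: flattened layer inputs with shape (tokens, in_features);
    # H: running (in_features, in_features) approximation of 2 * E[x x^T].
    n = x.shape[0]
    H = H * (total_tokens / (total_tokens + n))
    total_tokens += n
    xt = (2.0 / total_tokens) ** 0.5 * x.t().float()
    return H + xt @ xt.t(), total_tokens
```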
