Commit 7b5a097

Fix import versions for GPTQ (#105)
1 parent: a8704f8

3 files changed: +9 −6 lines changed

test/quantization/model.py

Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@
 from torch.nn import functional as F

 def prepare_inputs_for_model(inps):
+    inps = inps.squeeze(0)
     # setup inputs in correct format
     max_new_tokens = 1
     T = inps.size(0)

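The added squeeze(0) pairs with the unsqueeze(0) introduced in GPTQ.py below: the quantizer now hands inputs to the prep function with an explicit batch dimension, and prepare_inputs_for_model strips it again before building per-token inputs. A minimal sketch of the assumed shape contract (tensor contents are illustrative, not from the repo):

    import torch

    inps = torch.arange(8)           # [T] token ids fed to the calibration loop
    batched = inps.unsqueeze(0)      # [1, T], as _model_call now produces
    restored = batched.squeeze(0)    # [T], first step of prepare_inputs_for_model
    assert torch.equal(restored, inps)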
torchao/quantization/GPTQ.py

Lines changed: 4 additions & 3 deletions

@@ -20,7 +20,7 @@
 # from model import Transformer  # pyre-ignore[21]
 from torch.utils._pytree import tree_flatten, tree_unflatten

-from .utils import TORCH_VERSION_AFTER_2_4
+from .utils import TORCH_VERSION_AFTER_2_3
 from typing import Any, Dict, Tuple, Optional
 from .unified import Quantizer
 from functools import reduce

@@ -89,7 +89,7 @@ def __init__(
         # for model
         self.input_prep_func = (
             input_prep_func if input_prep_func is not None
-            else lambda x: x
+            else lambda x: (x,)
         )

         self.pad_calibration_inputs = pad_calibration_inputs

@@ -180,6 +180,7 @@ def _model_call(self, inps):
         else:
             inps = F.pad(inps, (self.pad_token, self.calibration_seq_length - T))

+        inps = inps.unsqueeze(0)
         model_in = self.input_prep_func(inps)

         self.add_input(model_in)

@@ -546,7 +547,7 @@ def faster_quant(self, H, W):
         return Q, DQ.to(orig_dtype), all_qparams


-if TORCH_VERSION_AFTER_2_4:
+if TORCH_VERSION_AFTER_2_3:
     from .quant_primitives import (
         get_group_qparams_symmetric,
         group_quantize_tensor_symmetric,

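The default input_prep_func changing from lambda x: x to lambda x: (x,) matters because the prepared value is treated as the model's full positional-argument tuple, so a bare tensor must be wrapped. A hedged sketch of that calling convention (the stand-in model and the *-unpacking are assumptions, not the library's exact internals):

    import torch

    def default_input_prep(x):
        # new default: wrap the tensor so it unpacks as one positional arg
        return (x,)

    def model(tokens):
        # stand-in for the model being calibrated
        return tokens.float().mean()

    inps = torch.randint(0, 100, (1, 32))   # [1, T] after the new unsqueeze(0)
    model_in = default_input_prep(inps)     # -> (tensor,)
    out = model(*model_in)                  # unpacks cleanly as positional args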
torchao/quantization/quant_api.py

Lines changed: 4 additions & 3 deletions

@@ -46,10 +46,14 @@
 from .GPTQ import (
     Int8DynActInt4WeightQuantizer,
     Int8DynActInt4WeightGPTQQuantizer,
+    Int4WeightQuantizer,
+    Int4WeightGPTQQuantizer,
 )
 __all__ += [
     "Int8DynActInt4WeightQuantizer",
     "Int8DynActInt4WeightGPTQQuantizer",
+    "Int4WeightQuantizer",
+    "Int4WeightGPTQQuantizer",
 ]


@@ -196,6 +200,3 @@ def replace_conv2d_1x1(conv):
     _replace_with_custom_fn_if_matches_filter(
         model, replace_conv2d_1x1, filter_fn=filter_fn
     )
-
-if TORCH_VERSION_AFTER_2_3:
-    from .GPTQ import Int8DynActInt4WeightQuantizer, Int8DynActInt4WeightGPTQQuantizer

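With these exports in place, the Int4 quantizer classes can be imported from quant_api alongside the Int8DynAct ones; the version gate now lives in GPTQ.py (relaxed to TORCH_VERSION_AFTER_2_3) rather than at the bottom of this file. A hedged usage sketch, assuming a PyTorch build new enough to pass that gate:

    from torchao.quantization.quant_api import (
        Int4WeightQuantizer,
        Int4WeightGPTQQuantizer,
        Int8DynActInt4WeightQuantizer,
        Int8DynActInt4WeightGPTQQuantizer,
    )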