
Commit 293ae7b

Refactor GPTQ Quantizer, remove lm_eval (#104)
Summary: Refactor the GPTQ code: remove GPTQ's lm_eval dependency, remove InputRecorder's dependency on the model, and make GPTQ work with gpt-fast. Also fix the model so that its kv_cache doesn't break GPTQ.

Test Plan: python test/quantization/test_quant_api.py

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: d6eb810
Pull Request resolved: #103
1 parent 5420089 commit 293ae7b
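
The key API change is that calibration is now a separate, explicit step. The sketch below assembles the new flow from the test diff in this commit; it assumes a gpt-fast style `model` and a SentencePiece `tokenizer` are already loaded, as in the tests.

    # Hedged sketch of the new flow; `model` and `tokenizer` are assumed to be
    # a gpt-fast Transformer and a SentencePieceProcessor, set up as in the tests.
    from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder
    from model import prepare_inputs_for_model

    blocksize, percdamp, groupsize = 128, 0.01, 128
    calibration_seq_length = 100

    # 1) Record calibration inputs (no lm_eval, no model reference needed).
    inputs = InputRecorder(
        tokenizer,
        calibration_seq_length,
        prepare_inputs_for_model,  # formats raw token ids into (x, input_pos)
        False,                     # pad_calibration_inputs
        model.config.vocab_size,
    ).record_inputs(["wikitext"], 1).get_inputs()

    # 2) Quantize, passing the recorded inputs explicitly.
    quantizer = Int8DynActInt4WeightGPTQQuantizer(blocksize, percdamp, groupsize)
    model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
    model = quantizer.quantize(model, inputs)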

File tree: 7 files changed, +896 −814 lines


test/quantization/model.py

Lines changed: 12 additions & 4 deletions

@@ -11,6 +11,16 @@
 from torch import Tensor
 from torch.nn import functional as F

+def prepare_inputs_for_model(inps):
+    # setup inputs in correct format
+    max_new_tokens = 1
+    T = inps.size(0)
+    T_new = T + max_new_tokens
+    seq = torch.empty(T_new, dtype=inps.dtype, device=inps.device)
+    seq[:T] = inps
+    input_pos = torch.arange(0, T, device=inps.device)
+    x = seq.index_select(0, input_pos).view(1, -1)
+    return (x, input_pos)

 def find_multiple(n: int, k: int) -> int:
     if n % k == 0:

@@ -76,10 +86,8 @@ def update(self, input_pos, k_val, v_val):
         # input_pos: [S], k_val: [B, H, S, D]
         assert input_pos.shape[0] == k_val.shape[2]

-        k_out = self.k_cache
-        v_out = self.v_cache
-        k_out[:, :, input_pos] = k_val
-        v_out[:, :, input_pos] = v_val
+        k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
+        v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)

         return k_out, v_out
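
The kv_cache change above replaces in-place advanced-index assignment with torch.ops.aten.index_put_, which is functionally identical; per the commit summary, this is what keeps the cache from breaking GPTQ. A small self-contained equivalence check (illustration only, not part of the commit):

    import torch

    # [B, H, S, D] cache and a 3-token update, mirroring KVCache.update
    k_cache = torch.zeros(1, 2, 8, 4)
    k_val = torch.randn(1, 2, 3, 4)
    input_pos = torch.arange(3)

    expected = k_cache.clone()
    expected[:, :, input_pos] = k_val  # old formulation

    # New formulation: index_put_ indexes only dim 2 (None = all of dims 0, 1).
    k_out = torch.ops.aten.index_put_(k_cache, [None, None, input_pos], k_val)
    assert torch.equal(k_out, expected)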

test/quantization/test_quant_api.py

Lines changed: 72 additions & 9 deletions

@@ -29,7 +29,7 @@
 )
 from pathlib import Path
 from sentencepiece import SentencePieceProcessor
-from model import Transformer
+from model import Transformer, prepare_inputs_for_model


 def dynamic_quant(model, example_inputs):

@@ -139,9 +139,9 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
     def test_8da4w_quantizer(self):
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
-        from torchao.quantization.quant_api import Int8DynActInt4WeightLinear
+        from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

-        quantizer = Int8DynActInt4WeightQuantizer(group_size=32)
+        quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
         m = M().eval()
         example_inputs = m.example_inputs()
         m = quantizer.quantize(m)

@@ -151,7 +151,7 @@ def test_8da4w_quantizer(self):

     @unittest.skip("skipping until we get checkpoints for gpt-fast")
     def test_gptq_quantizer(self):
-        from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer
+        from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder
         # should be similar to TorchCompileDynamicQuantizer
         precision = torch.bfloat16
         device = "cpu"

@@ -169,20 +169,83 @@ def test_gptq_quantizer(self):
         percdamp = 0.01
         groupsize = 128
         calibration_tasks = ["wikitext"]
-        calibration_limit = 5
+        calibration_limit = 1
         calibration_seq_length = 100
+        input_prep_func = prepare_inputs_for_model
         pad_calibration_inputs = False
-        quantizer = Int8DynActInt4WeightGPTQQuantizer(
+
+        inputs = InputRecorder(
             tokenizer,
+            calibration_seq_length,
+            input_prep_func,
+            pad_calibration_inputs,
+            model.config.vocab_size,
+        ).record_inputs(
+            calibration_tasks,
+            calibration_limit,
+        ).get_inputs()
+
+        quantizer = Int8DynActInt4WeightGPTQQuantizer(
             blocksize,
             percdamp,
             groupsize,
-            calibration_tasks,
-            calibration_limit,
+        )
+        model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
+        model = quantizer.quantize(model, inputs)
+        compiled = torch.compile(model, mode="max-autotune")
+        with torch.no_grad():
+            compiled(inputs[0].values[0], inputs[1].values[0])
+
+    @unittest.skip("skipping until we get checkpoints for gpt-fast")
+    def test_gptq_quantizer_gpt_fast(self):
+        from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder
+        # should be similar to TorchCompileDynamicQuantizer
+        precision = torch.bfloat16
+        device = "cuda"
+        checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
+        model = Transformer.from_name(checkpoint_path.parent.name)
+        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+        model.load_state_dict(checkpoint, assign=True)
+        model = model.to(dtype=precision, device=device)
+        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+        assert tokenizer_path.is_file(), tokenizer_path
+        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
+            model_file=str(tokenizer_path)
+        )
+        blocksize = 128
+        percdamp = 0.01
+        groupsize = 128
+        calibration_tasks = ["wikitext"]
+        calibration_limit = 1
+        calibration_seq_length = 100
+        input_prep_func = prepare_inputs_for_model
+        pad_calibration_inputs = False
+
+        inputs = InputRecorder(
+            tokenizer,
             calibration_seq_length,
+            input_prep_func,
             pad_calibration_inputs,
+            model.config.vocab_size,
+        ).record_inputs(
+            calibration_tasks,
+            calibration_limit,
+        ).get_inputs()
+
+        quantizer = Int8DynActInt4WeightGPTQQuantizer(
+            blocksize,
+            percdamp,
+            groupsize,
+            _is_gpt_fast=True,
+            _use_cuda=True,
         )
-        model = quantizer.quantize(model)
+
+        model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
+
+        model = quantizer.quantize(model, inputs)
+        compiled = torch.compile(model, mode="max-autotune")
+        with torch.no_grad():
+            compiled(inputs[0].values[0], inputs[1].values[0])

 if __name__ == "__main__":
     unittest.main()
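
For reference, the new prepare_inputs_for_model helper (added to model.py above and passed to InputRecorder as input_prep_func) reshapes a 1-D tensor of token ids into the (x, input_pos) pair the Transformer forward expects. A minimal usage sketch with made-up token ids:

    import torch
    from model import prepare_inputs_for_model  # test/quantization/model.py

    tokens = torch.tensor([1, 15043, 3186])  # any 1-D LongTensor of token ids (T = 3)
    x, input_pos = prepare_inputs_for_model(tokens)

    print(x)          # tensor([[    1, 15043,  3186]]) -- shape [1, T]
    print(input_pos)  # tensor([0, 1, 2])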
