@@ -29,7 +29,7 @@
 )
 from pathlib import Path
 from sentencepiece import SentencePieceProcessor
-from model import Transformer
+from model import Transformer, prepare_inputs_for_model


 def dynamic_quant(model, example_inputs):
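For reference, `prepare_inputs_for_model` is the hook that converts the flat token tensor recorded during calibration into the positional calling convention of the gpt-fast `Transformer` (`forward(tokens, input_pos)`). A minimal sketch of what such a hook might look like under that assumed signature; the body below is illustrative, not the actual gpt-fast source:

```python
import torch

def prepare_inputs_for_model(inps):
    # Illustrative only: reshape a 1-D tensor of calibration token ids into
    # the (tokens, input_pos) pair a gpt-fast style Transformer expects.
    tokens = inps.view(1, -1)  # add a batch dimension
    input_pos = torch.arange(tokens.shape[-1], device=inps.device)
    return (tokens, input_pos)
```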
@@ -139,9 +139,9 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
     def test_8da4w_quantizer(self):
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
-        from torchao.quantization.quant_api import Int8DynActInt4WeightLinear
+        from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

-        quantizer = Int8DynActInt4WeightQuantizer(group_size=32)
+        quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
         m = M().eval()
         example_inputs = m.example_inputs()
         m = quantizer.quantize(m)
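The two changes above track an API move and a rename: `Int8DynActInt4WeightLinear` now lives in `torchao.quantization.GPTQ` rather than `quant_api`, and the quantizer keyword is `groupsize` (one word) instead of `group_size`. A minimal sketch of the updated eager flow the test exercises, with a toy model standing in for the test's `M()`:

```python
import torch
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

# Toy stand-in for the test's M(); any module with nn.Linear children works.
m = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()

quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)  # keyword is now `groupsize`
m = quantizer.quantize(m)  # nn.Linear children become Int8DynActInt4WeightLinear

assert any(isinstance(mod, Int8DynActInt4WeightLinear) for mod in m.modules())
```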
@@ -151,7 +151,7 @@ def test_8da4w_quantizer(self):

     @unittest.skip("skipping until we get checkpoints for gpt-fast")
     def test_gptq_quantizer(self):
-        from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer
+        from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder
         # should be similar to TorchCompileDynamicQuantizer
         precision = torch.bfloat16
         device = "cpu"
@@ -169,20 +169,83 @@ def test_gptq_quantizer(self):
         percdamp = 0.01
         groupsize = 128
         calibration_tasks = ["wikitext"]
-        calibration_limit = 5
+        calibration_limit = 1
         calibration_seq_length = 100
+        input_prep_func = prepare_inputs_for_model
         pad_calibration_inputs = False
-        quantizer = Int8DynActInt4WeightGPTQQuantizer(
+
+        inputs = InputRecorder(
             tokenizer,
+            calibration_seq_length,
+            input_prep_func,
+            pad_calibration_inputs,
+            model.config.vocab_size,
+        ).record_inputs(
+            calibration_tasks,
+            calibration_limit,
+        ).get_inputs()
+
+        quantizer = Int8DynActInt4WeightGPTQQuantizer(
             blocksize,
             percdamp,
             groupsize,
-            calibration_tasks,
-            calibration_limit,
+        )
+        model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
+        model = quantizer.quantize(model, inputs)
+        compiled = torch.compile(model, mode="max-autotune")
+        with torch.no_grad():
+            compiled(inputs[0].values[0], inputs[1].values[0])
+
+    @unittest.skip("skipping until we get checkpoints for gpt-fast")
+    def test_gptq_quantizer_gpt_fast(self):
+        from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder
+        # should be similar to TorchCompileDynamicQuantizer
+        precision = torch.bfloat16
+        device = "cuda"
+        checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
+        model = Transformer.from_name(checkpoint_path.parent.name)
+        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+        model.load_state_dict(checkpoint, assign=True)
+        model = model.to(dtype=precision, device=device)
+        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+        assert tokenizer_path.is_file(), tokenizer_path
+        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
+            model_file=str(tokenizer_path)
+        )
+        blocksize = 128
+        percdamp = 0.01
+        groupsize = 128
+        calibration_tasks = ["wikitext"]
+        calibration_limit = 1
+        calibration_seq_length = 100
+        input_prep_func = prepare_inputs_for_model
+        pad_calibration_inputs = False
+
+        inputs = InputRecorder(
+            tokenizer,
             calibration_seq_length,
+            input_prep_func,
             pad_calibration_inputs,
+            model.config.vocab_size,
+        ).record_inputs(
+            calibration_tasks,
+            calibration_limit,
+        ).get_inputs()
+
+        quantizer = Int8DynActInt4WeightGPTQQuantizer(
+            blocksize,
+            percdamp,
+            groupsize,
+            _is_gpt_fast=True,
+            _use_cuda=True,
         )
-        model = quantizer.quantize(model)
+
+        model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
+
+        model = quantizer.quantize(model, inputs)
+        compiled = torch.compile(model, mode="max-autotune")
+        with torch.no_grad():
+            compiled(inputs[0].values[0], inputs[1].values[0])

 if __name__ == "__main__":
     unittest.main()
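Taken together, the new tests switch GPTQ calibration from a constructor-driven flow (where `calibration_tasks` and `calibration_limit` were passed to the quantizer itself) to a two-stage one: record inputs first with `InputRecorder`, then pass them to `quantize(model, inputs)`. A condensed sketch of that flow, with the checkpoint/tokenizer setup from the tests elided and the same variable names assumed:

```python
from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder

# Stage 1: record calibration inputs up front.
# `tokenizer` and `model` come from the setup elided here (see the test bodies).
inputs = InputRecorder(
    tokenizer,                  # SentencePieceProcessor from the checkpoint dir
    calibration_seq_length,     # 100 in the tests
    prepare_inputs_for_model,   # input-prep hook imported from model.py
    pad_calibration_inputs,     # False in the tests
    model.config.vocab_size,
).record_inputs(
    calibration_tasks,          # ["wikitext"]
    calibration_limit,          # 1
).get_inputs()

# Stage 2: quantize against the recorded inputs; calibration happens here now.
quantizer = Int8DynActInt4WeightGPTQQuantizer(blocksize, percdamp, groupsize)
model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
model = quantizer.quantize(model, inputs)
```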