@@ -195,15 +195,14 @@ def __init__(
         self,
         model,
         weight_config={},
-        dataloader=None,
         nsamples=128,
         use_max_length=True,
         max_seq_length=2048,
         device=None,
         export_compressed_model=False,
         use_layer_wise=False,
         model_path="",
-        run_fn=None,
+        dataloader=None,
         *args,
         **kwargs,
     ):
@@ -226,7 +225,6 @@ def __init__(
             export_compressed_model (bool, optional): Choose return fp32 or int32 model. Defaults to False.
             use_layer_wise (bool): Enables quantize model per layer. Defaults to False.
             model_path (str): Model path that is used to load state_dict per layer.
-            run_fn: a function to run model inference for collecting input information.
             device: cpu or cuda
         """
         # model
@@ -271,9 +269,7 @@ def __init__(
         self.dataloader_original = dataloader
         self.dataloader = []
         self.nsamples = nsamples
-        self.run_fn = run_fn
-        self.run_args = kwargs.get("run_args", None)
-        if run_fn is None:
+        if dataloader is not None:
             self.prepare_dataloader()

     def prepare_dataloader(self):
@@ -489,7 +485,7 @@ def track_hidden_states(self, data):
         return data[0]

     @torch.no_grad()
-    def pre_quantization(self):
+    def prepare_for_calibration(self):
         """Prepare input calibration data and other attributes which are critical for gptq execution."""
         try:
             self.cache_key_arguments = {
@@ -532,34 +528,13 @@ def forward(layer, *args, **kwargs):
         # Step2: modify the first transformer block's forward function to obtain inputs for calibration
         if not self.use_layer_wise:
             self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device)
-        forward_cache = self.gptq_related_blocks["transformers"][0].forward
+        self.forward_cache = self.gptq_related_blocks["transformers"][0].forward
         self.gptq_related_blocks["transformers"][0].forward = partial(
             forward, self.gptq_related_blocks["transformers"][0]
         )

-        # Step3: run forward to obtain calibration datasets
-        logger.info("Collecting calibration inputs...")
-        logger.info("Collecting calibration inputs by running the run_fn provided by user.")
-        if self.run_fn:
-            if self.run_args:
-                self.run_fn(self.model, *self.run_args)
-                accelerator.mark_step()
-            else:
-                self.run_fn(self.model)
-                accelerator.mark_step()
-        else:
-            for batch in tqdm(self.dataloader):
-                if not self.use_layer_wise:
-                    batch = move_input_to_device(batch, self.device)
-                try:
-                    if isinstance(batch, tuple) or isinstance(batch, list):
-                        self.model(batch[0])
-                    elif isinstance(batch, dict):
-                        self.model(**batch)
-                    else:
-                        self.model(batch)
-                except ValueError:
-                    pass
+    @torch.no_grad()
+    def remove_prepare_for_calibration(self):
         # output inp data shape
         logger.info("All calibration data's shape =>")
         # check all hidden_states shape
@@ -571,7 +546,7 @@ def forward(layer, *args, **kwargs):
         logger.info("Done.")

         # Step 4: restore original forward function, relocate layers back to cpu.
-        self.gptq_related_blocks["transformers"][0].forward = forward_cache
+        self.gptq_related_blocks["transformers"][0].forward = self.forward_cache
         if not self.use_layer_wise:
             self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
         for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
@@ -606,7 +581,6 @@ def execute_quantization(self, means=None, stds=None):
         # Step1: prepare quantization (calibration datasets)

         logger.info("Begin ====>")
-        self.pre_quantization()
         model_path = self.model_path

         # Step2: run gptq quantization in a transformer block-wise manner.
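With `pre_quantization()` split into `prepare_for_calibration()` and `remove_prepare_for_calibration()`, and its call removed from `execute_quantization()`, the forward passes that used to run inside the old Step3 must now be driven by the caller between those two methods. A minimal sketch of such an external calibration loop, mirroring the removed Step3 logic (the `run_calibration` helper and `calib_dataloader` name are illustrative, not part of this change):

```python
import torch
from tqdm import tqdm


@torch.no_grad()
def run_calibration(model, calib_dataloader):
    """Feed calibration batches through the prepared model.

    The patched first transformer block caches its inputs and may raise
    ValueError to cut the forward pass short, so that error is ignored.
    """
    for batch in tqdm(calib_dataloader):
        try:
            if isinstance(batch, (tuple, list)):
                model(batch[0])
            elif isinstance(batch, dict):
                model(**batch)
            else:
                model(batch)
        except ValueError:
            pass  # expected once the hooked block has captured its inputs
```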
@@ -1144,41 +1118,58 @@ def ready(self):
         return torch.all(self.scale != 0)


-def gptq_quantize(
-    model,
-    weight_config={},
-    dataloader=None,
-    nsamples=128,
-    max_seq_length=2048,
-    use_max_length=True,
-    device=None,
-    export_compressed_model=False,
-    use_layer_wise=False,
-    model_path=None,
-    run_fn=None,
-    run_args=None,
-):
-    """Run weight-only quantization with."""
-    # TODO: unify weight_config keys, add docstring, and support default config
-    assert isinstance(model, torch.nn.Module), "only support torch module"
-    if use_layer_wise:
-        assert model_path is not None, "model_path should not be None when use layer wise mode"
-    from .gptq import GPTQuantizer
-
-    gptq_quantizer = GPTQuantizer(
+from neural_compressor.torch.algorithms import Quantizer as INCQuantizer
+
+
+class INCGPTQQuantizer(INCQuantizer):
+    def __init__(self, quant_config={}):
+        """Init an INCGPTQQuantizer object.
+
+        Args:
+            quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
+        """
+        super().__init__(quant_config)
+
+    @torch.no_grad()
+    def prepare(
+        self,
         model,
-        weight_config,
-        dataloader,
-        nsamples,
-        use_max_length,
-        max_seq_length,
-        device,
-        export_compressed_model=export_compressed_model,
-        use_layer_wise=use_layer_wise,
-        model_path=model_path,
-        run_fn=run_fn,
-        run_args=run_args,
-    )
-    fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization()
-    logger.info("GPTQ quantizing done.")
-    return fp32_modified_model, gptq_config
+        nsamples=128,
+        max_seq_length=2048,
+        use_max_length=True,
+        device=None,
+        export_compressed_model=False,
+        use_layer_wise=False,
+        model_path=None,
+        *args,
+        **kwargs,
+    ):
+        """Run weight-only quantization with GPTQ."""
+        # TODO: unify weight_config keys, add docstring, and support default config
+        assert isinstance(model, torch.nn.Module), "only support torch module"
+        if use_layer_wise:
+            assert model_path is not None, "model_path should not be None when use layer wise mode"
+        from .gptq import GPTQuantizer
+
+        self.gptq_quantizer = GPTQuantizer(
+            model,
+            weight_config=self.quant_config,
+            nsamples=nsamples,
+            use_max_length=use_max_length,
+            max_seq_length=max_seq_length,
+            device=device,
+            export_compressed_model=export_compressed_model,
+            use_layer_wise=use_layer_wise,
+            model_path=model_path,
+        )
+        self.gptq_quantizer.prepare_for_calibration()
+        return self.gptq_quantizer.model
+
+    @torch.no_grad()
+    def convert(self, model, *args, **kwargs):
+        self.gptq_quantizer.model = model
+        self.gptq_quantizer.remove_prepare_for_calibration()
+        q_model, gptq_config = self.gptq_quantizer.execute_quantization()
+        q_model.gptq_config = gptq_config
+        logger.info("GPTQ quantizing done.")
+        return q_model
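Taken together, the new class adapts GPTQ to the framework's two-step prepare/convert flow: `prepare()` wraps the first transformer block so calibration inputs are cached, and `convert()` restores the original forward and runs the block-wise quantization. A hedged usage sketch, assuming a user-supplied calibration loop (such as the `run_calibration` helper sketched above) runs in between; the variable names below are placeholders, not taken from this diff:

```python
# Assumed names: weight_config, fp32_model, and calib_dataloader are user-provided.
quantizer = INCGPTQQuantizer(quant_config=weight_config)

# Step 1: patch the model so the first transformer block records calibration inputs.
prepared_model = quantizer.prepare(fp32_model, nsamples=128, max_seq_length=2048)

# Step 2: drive calibration forward passes externally (see run_calibration above).
run_calibration(prepared_model, calib_dataloader)

# Step 3: restore the original forward and execute block-wise GPTQ quantization.
q_model = quantizer.convert(prepared_model)
print(q_model.gptq_config)  # per-layer quantization parameters attached by convert()
```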