Skip to content

Commit 17a33d3

Browse files
committed
Enhance attribute naming for prepare-stage state (use a class-level `_PREPARE_ATTRS` list and `_PREPARE_ATTRS_PREFIX` constant instead of an instance-level `_prepared_attrs` list)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
1 parent 50b2f5e commit 17a33d3

File tree

1 file changed

+8
-6
lines changed
  • neural_compressor/torch/algorithms/weight_only

1 file changed

+8
-6
lines changed

neural_compressor/torch/algorithms/weight_only/teq.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@
3232

3333

3434
class TrainableEquivalentTransformation:
35-
"""Weight-only quantization, Trainable Equivalent Transformation (TEQ): linear wrapper to apply scale to input."""
35+
"""Weight-only quantization, Trainable Equivalent Transformation (TEQ)."""
36+
37+
_PREPARE_ATTRS: list[str] = ["weight_config", "trained_alphas"]
38+
_PREPARE_ATTRS_PREFIX = "_prepare_"
3639

3740
def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, example_inputs=None):
3841
"""
@@ -47,7 +50,6 @@ def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, ex
4750
self.device = self._get_device()
4851
self.trained_alphas = {}
4952
self.absorb_to_layer = absorb_to_layer
50-
self._prepared_attrs = ["weight_config", "trained_alphas"]
5153
self._post_initialized = False
5254

5355
def _post_init(self):
@@ -353,13 +355,13 @@ def prepare(self, model, *args, **kwargs):
353355
self._quantizer.model = float_model
354356
logger.info("TEQ quantizing start.")
355357
self._quantizer.add_tuning_scale()
356-
for attr in self._quantizer._prepared_attrs:
357-
setattr(float_model, "_" + attr, getattr(self._quantizer, attr))
358+
for attr in self._quantizer._PREPARE_ATTRS:
359+
setattr(float_model, self._quantizer._PREPARE_ATTRS_PREFIX + attr, getattr(self._quantizer, attr))
358360
return float_model
359361

360362
def convert(self, model, *args: Any, **kwargs: Any):
361-
for attr in self._quantizer._prepared_attrs:
362-
setattr(self._quantizer, attr, getattr(model, "_" + attr, None))
363+
for attr in self._quantizer._PREPARE_ATTRS:
364+
setattr(self._quantizer, attr, getattr(model, self._quantizer._PREPARE_ATTRS_PREFIX + attr, None))
363365
self._quantizer.model = model
364366
self._quantizer.transform()
365367
self._quantizer.quantize()

0 commit comments

Comments (0)