Skip to content

Commit d209c9a

Browse files
committed
enhance base Quantizer
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
1 parent 76fda52 commit d209c9a

File tree

3 files changed

+9
-38
lines changed

3 files changed

+9
-38
lines changed

neural_compressor/torch/algorithms/base_algorithm.py

Lines changed: 9 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -85,19 +85,18 @@ def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
         Returns:
             A quantized model.
         """
+        model = self.prepare(model, *args, **kwargs)
+
         run_fn = kwargs.get("run_fn", None)
-        run_args = kwargs.get("run_args", None)
-        assert run_fn is not None, (
-            "Can't find run_func. Please provide run_func to quantize API "
-            "or overwrite quantize member function in your Quantizer class."
-        )
+        if run_fn is not None:
+            run_args = kwargs.get("run_args", None)
+            if run_args:
+                run_fn(model, *run_args)
+            else:
+                run_fn(model)

-        model = self.prepare(model, *args, **kwargs)
-        if run_args:
-            run_fn(model, *run_args)
-        else:
-            run_fn(model)
         model = self.convert(model, *args, **kwargs)
+
         return model

     def execute(self, model: torch.nn.Module, mode, *args: Any, **kwargs: Any):  # pragma: no cover

neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py

Lines changed: 0 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -114,20 +114,6 @@ def convert(self, model: torch.nn.Module, *args, **kwargs) -> Optional[torch.nn.
         )
         return model

-    @torch.no_grad()
-    def quantize(self, model: torch.nn.Module, *args, **kwargs):
-        """Quantizes a float torch model.
-
-        Args:
-            model: A float model to be quantized.
-
-        Returns:
-            A quantized model.
-        """
-        model = self.prepare(model, *args, **kwargs)
-        model = self.convert(model, *args, **kwargs)
-        return model
-
     def save(self, model, path):
         # TODO: to implement it in the next PR
         pass

neural_compressor/torch/algorithms/weight_only/rtn.py

Lines changed: 0 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -210,17 +210,3 @@ def convert(
         m.weight.t_().contiguous()
         m.weight.data.copy_(weight)
         return model
-
-    @torch.no_grad()
-    def quantize(self, model, *args, **kwargs):
-        """Quantizes a given torch model.
-
-        Args:
-            model: A float model to be quantized.
-
-        Returns:
-            A quantized model.
-        """
-        model = self.prepare(model, *args, **kwargs)
-        model = self.convert(model, *args, **kwargs)
-        return model

0 commit comments

Comments (0)