Skip to content

Commit 9c9324d

Browse files
ClarkChin08 and ftian1
authored and committed
[fix] fix line too long and ut failure
1 parent 12bf7d9 commit 9c9324d

File tree

5 files changed

+63
-46
lines changed

5 files changed

+63
-46
lines changed

lpot/experimental/benchmark.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -132,16 +132,19 @@ def b_dataloader(self, dataloader):
132132
dataloader(generator): user are supported to set a user defined dataloader
133133
which meet the requirements that can yield tuple of
134134
(input, label)/(input, _) batched data.
135-
Another good practice is to use lpot.experimental.common.DataLoader
135+
Another good practice is to use
136+
lpot.experimental.common.DataLoader
136137
to initialize a lpot dataloader object.
137-
Notice lpot.experimental.common.DataLoader is just a wrapper of the
138-
information needed to build a dataloader, it can't yield
138+
Notice lpot.experimental.common.DataLoader
139+
is just a wrapper of the information needed to
140+
build a dataloader, it can't yield
139141
batched data and only in this setter method
140142
a 'real' eval_dataloader will be created,
141143
the reason is we have to know the framework info
142144
and only after the Quantization object created then
143-
framework infomation can be known. Future we will support
144-
creating iterable dataloader from lpot.experimental.common.DataLoader
145+
framework infomation can be known.
146+
Future we will support creating iterable dataloader
147+
from lpot.experimental.common.DataLoader
145148
146149
"""
147150
from .common import _generate_common_dataloader
@@ -157,12 +160,15 @@ def model(self, user_model):
157160
158161
Args:
159162
user_model: user are supported to set model from original framework model format
160-
(eg, tensorflow frozen_pb or path to a saved model), but not recommended.
161-
Best practice is to set from a initialized lpot.experimental.common.Model.
162-
If tensorflow model is used, model's inputs/outputs will be auto inferenced,
163-
but sometimes auto inferenced inputs/outputs will not meet your requests,
164-
set them manually in config yaml file. Another corner case is slim model
165-
of tensorflow, be careful of the name of model configured in yaml file,
163+
(eg, tensorflow frozen_pb or path to a saved model),
164+
but not recommended. Best practice is to set from a initialized
165+
lpot.experimental.common.Model.
166+
If tensorflow model is used, model's inputs/outputs will be
167+
auto inferenced, but sometimes auto inferenced
168+
inputs/outputs will not meet your requests,
169+
set them manually in config yaml file.
170+
Another corner case is slim model of tensorflow,
171+
be careful of the name of model configured in yaml file,
166172
make sure the name is in supported slim model list.
167173
168174
"""
@@ -200,10 +206,11 @@ def metric(self, user_metric):
200206
and user_metric.metric_cls should be sub_class of lpot.metric.BaseMetric.
201207
202208
Args:
203-
user_metric(lpot.experimental.common.Metric): user_metric should be object initialized from
204-
lpot.experimental.common.Metric, in this method the
205-
user_metric.metric_cls will be registered to
206-
specific frameworks and initialized.
209+
user_metric(lpot.experimental.common.Metric):
210+
user_metric should be object initialized from
211+
lpot.experimental.common.Metric, in this method the
212+
user_metric.metric_cls will be registered to
213+
specific frameworks and initialized.
207214
208215
"""
209216
from .common import Metric as LpotMetric
@@ -234,7 +241,8 @@ def postprocess(self, user_postprocess):
234241
235242
Args:
236243
user_postprocess(lpot.experimental.common.Postprocess):
237-
user_postprocess should be object initialized from lpot.experimental.common.Postprocess,
244+
user_postprocess should be object initialized from
245+
lpot.experimental.common.Postprocess,
238246
in this method the user_postprocess.postprocess_cls will be
239247
registered to specific frameworks and initialized.
240248

lpot/experimental/pruning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def model(self, user_model):
157157
'output_tensor_names': cfg.model.outputs,
158158
'workspace_path': cfg.tuning.workspace.path})
159159

160-
from .model import MODELS
160+
from ..model import MODELS
161161
self._model = MODELS[self.framework](\
162162
user_model.root, framework_model_info, **user_model.kwargs)
163163

lpot/experimental/quantization.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,10 @@ def calib_dataloader(self, dataloader):
183183
Args:
184184
dataloader(generator): user are supported to set a user defined dataloader
185185
which meet the requirements that can yield tuple of
186-
(input, label)/(input, _) batched data.
187-
Another good practice is to use lpot.experimental.common.DataLoader
188-
to initialize a lpot dataloader object.
189-
Notice lpot.experimental.common.DataLoader is just a wrapper of the
186+
(input, label)/(input, _) batched data. Another good
187+
practice is to use lpot.experimental.common.DataLoader
188+
to initialize a lpot dataloader object. Notice
189+
lpot.experimental.common.DataLoader is just a wrapper of the
190190
information needed to build a dataloader, it can't yield
191191
batched data and only in this setter method
192192
a 'real' calib_dataloader will be created,
@@ -215,16 +215,19 @@ def eval_dataloader(self, dataloader):
215215
dataloader(generator): user are supported to set a user defined dataloader
216216
which meet the requirements that can yield tuple of
217217
(input, label)/(input, _) batched data.
218-
Another good practice is to use lpot.experimental.common.DataLoader
218+
Another good practice is to use
219+
lpot.experimental.common.DataLoader
219220
to initialize a lpot dataloader object.
220-
Notice lpot.experimental.common.DataLoader is just a wrapper of the
221-
information needed to build a dataloader, it can't yield
221+
Notice lpot.experimental.common.DataLoader
222+
is just a wrapper of the information needed to
223+
build a dataloader, it can't yield
222224
batched data and only in this setter method
223225
a 'real' eval_dataloader will be created,
224226
the reason is we have to know the framework info
225227
and only after the Quantization object created then
226-
framework infomation can be known. Future we will support
227-
creating iterable dataloader from lpot.experimental.common.DataLoader
228+
framework infomation can be known.
229+
Future we will support creating iterable dataloader
230+
from lpot.experimental.common.DataLoader
228231
229232
"""
230233
from .common import _generate_common_dataloader
@@ -241,12 +244,15 @@ def model(self, user_model):
241244
242245
Args:
243246
user_model: user are supported to set model from original framework model format
244-
(eg, tensorflow frozen_pb or path to a saved model), but not recommended.
245-
Best practice is to set from a initialized lpot.experimental.common.Model.
246-
If tensorflow model is used, model's inputs/outputs will be auto inferenced,
247-
but sometimes auto inferenced inputs/outputs will not meet your requests,
248-
set them manually in config yaml file. Another corner case is slim model
249-
of tensorflow, be careful of the name of model configured in yaml file,
247+
(eg, tensorflow frozen_pb or path to a saved model),
248+
but not recommended. Best practice is to set from a initialized
249+
lpot.experimental.common.Model.
250+
If tensorflow model is used, model's inputs/outputs will be
251+
auto inferenced, but sometimes auto inferenced
252+
inputs/outputs will not meet your requests,
253+
set them manually in config yaml file.
254+
Another corner case is slim model of tensorflow,
255+
be careful of the name of model configured in yaml file,
250256
make sure the name is in supported slim model list.
251257
252258
"""
@@ -284,10 +290,11 @@ def metric(self, user_metric):
284290
and user_metric.metric_cls should be sub_class of lpot.metric.BaseMetric.
285291
286292
Args:
287-
user_metric(lpot.experimental.common.Metric): user_metric should be object initialized from
288-
lpot.experimental.common.Metric, in this method the
289-
user_metric.metric_cls will be registered to
290-
specific frameworks and initialized.
293+
user_metric(lpot.experimental.common.Metric):
294+
user_metric should be object initialized from
295+
lpot.experimental.common.Metric, in this method the
296+
user_metric.metric_cls will be registered to
297+
specific frameworks and initialized.
291298
292299
"""
293300
from .common import Metric as LpotMetric
@@ -317,7 +324,8 @@ def postprocess(self, user_postprocess):
317324
318325
Args:
319326
user_postprocess(lpot.experimental.common.Postprocess):
320-
user_postprocess should be object initialized from lpot.experimental.common.Postprocess,
327+
user_postprocess should be object initialized from
328+
lpot.experimental.common.Postprocess,
321329
in this method the user_postprocess.postprocess_cls will be
322330
registered to specific frameworks and initialized.
323331
@@ -328,7 +336,8 @@ def postprocess(self, user_postprocess):
328336
postprocess_cfg = {user_postprocess.name : {**user_postprocess.kwargs}}
329337
if deep_get(self.conf.usr_cfg, "evaluation.accuracy.postprocess"):
330338
logger.warning('already set postprocess in yaml file, will override it...')
331-
deep_set(self.conf.usr_cfg, "evaluation.accuracy.postprocess.transform", postprocess_cfg)
339+
deep_set(
340+
self.conf.usr_cfg, "evaluation.accuracy.postprocess.transform", postprocess_cfg)
332341
from ..data import TRANSFORMS
333342
postprocesses = TRANSFORMS(self.framework, 'postprocess')
334343
postprocesses.register(user_postprocess.name, user_postprocess.postprocess_cls)

lpot/model/model.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def __init__(self, model, framework_specific_info={}, **kwargs):
329329
self.workspace_path = deep_get(framework_specific_info, 'workspace_path', './')
330330
self.kwargs = copy.deepcopy(kwargs)
331331
kwargs.update({'name': deep_get(framework_specific_info, 'name')})
332-
# self.model = model
332+
self._model = model
333333
self.sess, self._input_tensor_names, self._output_tensor_names = \
334334
create_session_with_input_output(
335335
model, input_tensor_names, output_tensor_names, **kwargs)
@@ -427,24 +427,24 @@ def graph_def(self):
427427
@graph_def.setter
428428
def graph_def(self, graph_def):
429429
temp_saved_model_path = os.path.join(self.workspace_path, 'temp_saved_model')
430-
replace_graph_def_of_saved_model(self.model, temp_saved_model_path, graph_def)
431-
self.model = temp_saved_model_path
430+
replace_graph_def_of_saved_model(self._model, temp_saved_model_path, graph_def)
431+
self._model = temp_saved_model_path
432432
self.sess.close()
433433
self.sess, self._input_tensor_names, self._output_tensor_names = \
434-
create_session_with_input_output(self.model, \
434+
create_session_with_input_output(self._model, \
435435
self._input_tensor_names, self._output_tensor_names)
436436

437437
@property
438438
def graph(self):
439439
return self.sess.graph
440440

441441
def save(self, root):
442-
if root is not self.model:
442+
if root is not self._model:
443443
assert os.path.isdir(root), 'please supply the path to save the model....'
444444
import shutil
445-
file_names = os.listdir(self.model)
445+
file_names = os.listdir(self._model)
446446
for f in file_names:
447-
shutil.move(os.path.join(self.model, f), root)
447+
shutil.move(os.path.join(self._model, f), root)
448448

449449
# class TensorflowKerasModel(TensorflowBaseModel):
450450
#

test/test_tensorflow_graph_matmul_fusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def test_matmul_biasadd_requantize_dequantize_fusion_with_softmax(self):
247247
output_graph = quantizer()
248248

249249
count=0
250-
for i in output_graph.model.node:
250+
for i in output_graph.model.as_graph_def().node:
251251
if i.op == 'QuantizedMatMulWithBiasAndDequantize':
252252
count += 1
253253
found_quantized_matmul = bool(count > 1)

0 commit comments

Comments
 (0)