Skip to content

Commit 60833fa

Browse files
ClarkChin08ftian1
authored andcommitted
[fix] fix blendcnn and coco import error
1 parent 0754a06 commit 60833fa

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

examples/pytorch/blendcnn/classify.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,13 +196,16 @@ def eval_func(model):
196196
# print(f"Accuracy: {total_accuracy}")
197197

198198
if args.tune:
199-
import lpot
199+
from lpot.experimental import Quantization
200200
# lpot tune
201201
model.load_state_dict(torch.load(args.input_model))
202202
dataloader = Bert_DataLoader(loader=data_iter, batch_size=args.batch_size)
203203

204-
quantizer = lpot.Quantization(args.tuned_yaml)
205-
q_model = quantizer(model, q_dataloader=dataloader, eval_func=eval_func)
204+
quantizer = Quantization(args.tuned_yaml)
205+
quantizer.calib_dataloader = dataloader
206+
quantizer.model = model
207+
quantizer.eval_func = eval_func
208+
q_model = quantizer()
206209
q_model.save(args.tuned_checkpoint)
207210

208211
elif args.int8:

lpot/experimental/metric/metric.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ class TensorflowCOCOMAP(BaseMetric):
608608
609609
"""
610610
def __init__(self, anno_path=None):
611-
from lpot.metric.coco_label_map import category_map
611+
from .coco_label_map import category_map
612612
if anno_path:
613613
import os
614614
import json
@@ -631,7 +631,7 @@ def __init__(self, anno_path=None):
631631
[cat for cat in self.category_map]) #index
632632

633633
def update(self, predicts, labels, sample_weight=None):
634-
from lpot.metric.coco_tools import ExportSingleImageGroundtruthToCoco,\
634+
from .coco_tools import ExportSingleImageGroundtruthToCoco,\
635635
ExportSingleImageDetectionBoxesToCoco
636636
bbox, str_label,int_label, image_id = labels
637637
detection = {}
@@ -690,7 +690,7 @@ def reset(self):
690690
self.annotation_id = 1
691691

692692
def result(self):
693-
from lpot.metric.coco_tools import COCOWrapper, COCOEvalWrapper
693+
from .coco_tools import COCOWrapper, COCOEvalWrapper
694694
if len(self.ground_truth_list) == 0:
695695
logger.warning("sample num is 0 can't calculate mAP")
696696
return 0

0 commit comments

Comments
 (0)