-
Notifications
You must be signed in to change notification settings - Fork 772
Bug Fixes for PTQ and ACQ based OpenVINO Model Export and Added Test Cases #2594
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 39 commits
b0fc113
97e9fa2
9c744eb
7f7cf8b
05ab70d
8faa2c6
c08eb45
cad5731
9a0c7cd
7cbf7e0
a02aa37
87dde1c
5d2a3bc
7888e12
3bf7af8
1fbf2fb
1f4948d
af8c734
2233dbd
96860c7
d8235c8
42fc4d5
e55b6ca
177baaf
e0c8cc5
ad3ac99
65a58ad
c9f971c
835bf32
2dc5a19
0bb9fd8
01c6c05
7d27a29
9293bd7
a30744e
9dae140
0737ddc
1425059
8a6ec1f
2ced95b
10020b5
2b06dd4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -321,7 +321,7 @@ def _post_training_quantization_ov( | |
f">300 images recommended for INT8 quantization, found only {len(dataloader.dataset)} images", | ||
) | ||
|
||
calibration_dataset = nncf.Dataset(dataloader, lambda x: x["image"]) | ||
calibration_dataset = nncf.Dataset(dataloader, lambda x: x.image) | ||
return nncf.quantize(model, calibration_dataset) | ||
|
||
@staticmethod | ||
|
@@ -364,6 +364,11 @@ def _accuracy_control_quantization_ov( | |
msg = "Metric must be provided for OpenVINO INT8_ACQ compression" | ||
raise ValueError(msg) | ||
|
||
# Setting up the fields parameter in Metric if Metric is initialized with placeholder. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you mean without a placeholder? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What Iβm trying to convey is this: Letβs say the user passes a metric via the CLI using I also noticed a couple of grammatical mistakes in an inline comment and a logger statement:
I hope this clarifies your question. Let me know if you have any further doubtsβIβd be happy to help! Thank you. |
||
if metric.fields[0] == "": | ||
metric.fields = ("anomaly_map", "gt_mask") if task == TaskType.SEGMENTATION else ("pred_score", "gt_label") | ||
logger.info(f"The fields of metric are initialized empty. Setting it to model fields {metric.fields}") | ||
|
||
model_input = model.input(0) | ||
|
||
if model_input.partial_shape[0].is_static: | ||
|
@@ -376,15 +381,20 @@ def _accuracy_control_quantization_ov( | |
f">300 images recommended for INT8 quantization, found only {len(dataloader.dataset)} images", | ||
) | ||
|
||
calibration_dataset = nncf.Dataset(dataloader, lambda x: x["image"]) | ||
calibration_dataset = nncf.Dataset(dataloader, lambda x: x.image) | ||
validation_dataset = nncf.Dataset(datamodule.test_dataloader()) | ||
|
||
# validation function to evaluate the quality loss after quantization | ||
def val_fn(nncf_model: "CompiledModel", validation_data: Iterable) -> float: | ||
for batch in validation_data: | ||
preds = torch.from_numpy(nncf_model(batch["image"])[0]) | ||
target = batch["label"] if task == TaskType.CLASSIFICATION else batch["mask"][:, None, :, :] | ||
metric.update(preds, target) | ||
preds = nncf_model(batch.image) | ||
for key, pred in preds.items(): | ||
name = key.get_any_name() | ||
setattr(batch, name, torch.from_numpy(pred)) | ||
if batch.gt_mask is not None: | ||
batch.gt_mask = batch.gt_mask.unsqueeze(dim=1) | ||
batch.pred_score = batch.pred_score.squeeze(dim=1) # Squeezing since it is binary. (B, 1) -> (B) | ||
metric.update(batch) | ||
return metric.compute() | ||
|
||
return nncf.quantize_with_accuracy_control(model, calibration_dataset, validation_dataset, val_fn) | ||
|
Uh oh!
There was an error while loading. Please reload this page.