Skip to content

Commit 7b5b447

Browse files
author
Matt Sokoloff
committed
wip
1 parent 9c57cc6 commit 7b5b447

File tree

3 files changed

+16
-28
lines changed

3 files changed

+16
-28
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ test-staging: build
66
docker run -it -v ${PWD}:/usr/src -w /usr/src \
77
-e LABELBOX_TEST_ENVIRON="staging" \
88
-e LABELBOX_TEST_API_KEY_STAGING=${LABELBOX_TEST_API_KEY_STAGING} \
9-
local/labelbox-python:test pytest $(PATH_TO_TEST) -svv
9+
local/labelbox-python:test pytest $(PATH_TO_TEST) -svvx
1010

1111
test-prod: build
1212
docker run -it -v ${PWD}:/usr/src -w /usr/src \

labelbox/data/metrics/iou.py

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ def subclassification_miou(
101101
miou across all subclasses.
102102
"""
103103

104-
subclass_predictions = _create_schema_lookup(subclass_predictions)
105-
subclass_labels = _create_schema_lookup(subclass_labels)
104+
subclass_predictions = _create_name_lookup(subclass_predictions)
105+
subclass_labels = _create_name_lookup(subclass_labels)
106106
feature_schemas = set(subclass_predictions.keys()).union(
107107
set(subclass_labels.keys()))
108108
classification_iou = [
@@ -188,11 +188,7 @@ def feature_miou(predictions: List[Union[ObjectAnnotation,
188188
f"Unexpected annotation found. Found {type(predictions[0])}")
189189

190190

191-
def _create_schema_lookup(annotations: List[BaseAnnotation]):
192-
grouped_annotations = defaultdict(list)
193-
for annotation in annotations:
194-
grouped_annotations[annotation.schema_id] = annotation
195-
return grouped_annotations
191+
196192

197193

198194
def data_row_miou(ground_truth: Label,
@@ -210,11 +206,9 @@ def data_row_miou(ground_truth: Label,
210206
Returns:
211207
float indicating the iou score for this data row.
212208
"""
213-
annotation_types = None if include_classifications else Geometry
214-
prediction_annotations = predictions.get_annotations_by_attr(
215-
attr="name", annotation_types=annotation_types)
216-
ground_truth_annotations = ground_truth.get_annotations_by_attr(
217-
attr="name", annotation_types=annotation_types)
209+
210+
prediction_annotations = _create_name_lookup(predictions.annotations)
211+
ground_truth_annotations = _create_name_lookup(ground_truth.annotations)
218212
feature_schemas = set(prediction_annotations.keys()).union(
219213
set(ground_truth_annotations.keys()))
220214
ious = [
@@ -228,6 +222,11 @@ def data_row_miou(ground_truth: Label,
228222
return None
229223
return np.mean(ious)
230224

225+
def _create_name_lookup(annotations: List[BaseAnnotation]):
226+
grouped_annotations = defaultdict(list)
227+
for annotation in annotations:
228+
grouped_annotations[annotation.name] = annotation
229+
return grouped_annotations
231230

232231
def _get_vector_pairs(predictions: List[Geometry],
233232
ground_truths: List[Geometry]):
@@ -249,15 +248,3 @@ def _polygon_iou(poly1: Polygon, poly2: Polygon) -> float:
249248
def _mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
250249
"""Computes iou between two binary segmentation masks."""
251250
return np.sum(mask1 & mask2) / np.sum(mask1 | mask2)
252-
253-
254-
def _remove_opacity_channel(masks: List[np.ndarray]) -> List[np.ndarray]:
255-
return [mask[:, :, :3] if mask.shape[-1] == 4 else mask for mask in masks]
256-
257-
258-
def _instance_urls_to_binary_mask(urls: List[str],
259-
color: Tuple[int, int, int]) -> np.ndarray:
260-
"""Downloads segmentation masks and turns the image into a binary mask."""
261-
masks = _remove_opacity_channel([url_to_numpy(url) for url in urls])
262-
return np.sum([np.all(mask == color, axis=-1) for mask in masks],
263-
axis=0) > 0

tests/data/metrics/test_iou.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
import numpy as np
55
import base64
66

7-
from labelbox.data.metrics.iou import datarow_miou
7+
from labelbox.data.metrics.iou import data_row_miou
8+
from labelbox.data.serialization import NDJsonConverter, LBV1Converter
89

910

1011
def check_iou(pair):
11-
assert datarow_miou(pair.labels, pair.predictions) == pair.expected
12+
assert data_row_miou(next(LBV1Converter.deserialize(pair.labels)), next(NDJsonConverter.deserialize(pair.predictions)) ) == pair.expected
1213

1314

1415
def strings_to_fixtures(strings):
@@ -72,5 +73,5 @@ def test_vector_with_subclass(pair):
7273

7374
@parametrize("pair", strings_to_fixtures(["point_pair", "line_pair"]))
7475
def test_others(pair):
75-
assert math.isclose(datarow_miou(pair.labels, pair.predictions),
76+
assert math.isclose(data_row_miou(pair.labels, pair.predictions),
7677
pair.expected)

0 commit comments

Comments
 (0)