Skip to content

Commit 138e2bf

Browse files
authored
Merge pull request #178 from Labelbox/ms/mea-seg-metrics
mult-class seg iou
2 parents faa0bd6 + 27662c3 commit 138e2bf

File tree

4 files changed

+40
-20
lines changed

4 files changed

+40
-20
lines changed

labelbox/data/metrics/iou.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
def mask_miou(predictions: List[NDMask], labels: List[NDMask]) -> float:
2121
"""
22-
Creates prediction and label binary mask for all features with the same feature scheama id.
22+
Creates prediction and label binary mask for all features with the same feature schema id.
2323
2424
Args:
2525
predictions: List of masks objects
@@ -28,10 +28,18 @@ def mask_miou(predictions: List[NDMask], labels: List[NDMask]) -> float:
2828
float indicating iou score
2929
"""
3030

31+
colors_pred = {tuple(pred.mask['colorRGB']) for pred in predictions}
32+
colors_label = {tuple(label.mask['colorRGB']) for label in labels}
33+
error_msg = "segmentation {} should all have the same color. Found {}"
34+
if len(colors_pred) > 1:
35+
raise ValueError(error_msg.format("predictions", colors_pred))
36+
elif len(colors_label) > 1:
37+
raise ValueError(error_msg.format("labels", colors_label))
38+
3139
pred_mask = _instance_urls_to_binary_mask(
32-
[pred.mask['instanceURI'] for pred in predictions])
40+
[pred.mask['instanceURI'] for pred in predictions], colors_pred.pop())
3341
label_mask = _instance_urls_to_binary_mask(
34-
[label.mask['instanceURI'] for label in labels])
42+
[label.mask['instanceURI'] for label in labels], colors_label.pop())
3543
assert label_mask.shape == pred_mask.shape
3644
return _mask_iou(label_mask, pred_mask)
3745

@@ -282,7 +290,13 @@ def _mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
282290
return np.sum(mask1 & mask2) / np.sum(mask1 | mask2)
283291

284292

285-
def _instance_urls_to_binary_mask(urls: List[str]) -> np.ndarray:
293+
def _remove_opacity_channel(masks: List[np.ndarray]) -> List[np.ndarray]:
294+
return [mask[:, :, :3] if mask.shape[-1] == 4 else mask for mask in masks]
295+
296+
297+
def _instance_urls_to_binary_mask(urls: List[str],
298+
color: Tuple[int, int, int]) -> np.ndarray:
286299
"""Downloads segmentation masks and turns the image into a binary mask."""
287-
masks = [url_to_numpy(url) for url in urls]
288-
return np.sum(masks, axis=(0, 3)) > 0
300+
masks = _remove_opacity_channel([url_to_numpy(url) for url in urls])
301+
return np.sum([np.all(mask == color, axis=-1) for mask in masks],
302+
axis=0) > 0

labelbox/data/metrics/preprocess.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import functools
12
from typing import List, Dict, Any, Union
23
from collections import defaultdict
34
import numpy as np # type: ignore
@@ -35,7 +36,8 @@ def label_to_ndannotation(label: Dict[str, Any],
3536
if tool in SEGMENTATION_TOOLS:
3637
label['mask'] = {
3738
'instanceURI': label['instanceURI'],
38-
'colorRGB': (0, 0, 0)
39+
# Matches the color in the seg masks in the exports
40+
'colorRGB': (255, 255, 255)
3941
}
4042
for unused_key in unused_keys:
4143
label.pop(unused_key, None)
@@ -66,7 +68,8 @@ def create_schema_lookup(rows: List[NDBase]) -> Dict[str, List[Any]]:
6668
return data
6769

6870

69-
@retry.Retry(deadline=10.)
71+
@retry.Retry(deadline=15.)
72+
@functools.lru_cache(maxsize=256)
7073
def url_to_numpy(mask_url: str) -> np.ndarray:
7174
""" Downloads an image and converts to a numpy array """
7275
return np.array(Image.open(BytesIO(requests.get(mask_url).content)))

tests/data/metrics/conftest.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44
import numpy as np
55
from PIL import Image
6+
import base64
67

78

89
class NameSpace(SimpleNamespace):
@@ -225,21 +226,22 @@ def unmatched_label():
225226
expected=0.25)
226227

227228

228-
def create_mask_url(indices, h, w):
229+
def create_mask_url(indices, h, w, value):
229230
mask = np.zeros((h, w, 3), dtype=np.uint8)
230231
for idx in indices:
231-
mask[idx] = 1
232-
return mask.tobytes()
232+
mask[idx] = value
233+
return base64.b64encode(mask.tobytes()).decode('utf-8')
233234

234235

235236
@pytest.fixture
236237
def mask_pair():
237-
#* Use your own signed urls so that you can resign the data
238-
#* This is just to make the demo work
239238
return NameSpace(labels=[{
240-
'featureId': '1234567890111213141516171',
241-
'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
242-
'instanceURI': create_mask_url([(0, 0, 0), (0, 1, 0)], 32, 32)
239+
'featureId':
240+
'1234567890111213141516171',
241+
'schemaId':
242+
'ckppid25v0000aeyjmxfwlc7t',
243+
'instanceURI':
244+
create_mask_url([(0, 0), (0, 1)], 32, 32, (255, 255, 255))
243245
}],
244246
predictions=[{
245247
'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a',
@@ -249,7 +251,7 @@ def mask_pair():
249251
},
250252
'mask': {
251253
'instanceURI':
252-
create_mask_url([(0, 0, 0)], 32, 32),
254+
create_mask_url([(0, 0)], 32, 32, (1, 1, 1)),
253255
'colorRGB': (1, 1, 1)
254256
}
255257
}],

tests/data/metrics/test_iou.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from unittest.mock import patch
33
import math
44
import numpy as np
5+
import base64
56

67
from labelbox.data.metrics.iou import datarow_miou
78

@@ -18,9 +19,9 @@ def test_overlapping(polygon_pair, box_pair, mask_pair):
1819
check_iou(polygon_pair)
1920
check_iou(box_pair)
2021
with patch('labelbox.data.metrics.iou.url_to_numpy',
21-
side_effect=lambda x: np.frombuffer(x.encode('utf-8'),
22-
dtype=np.uint8).reshape(
23-
(32, 32, 3))):
22+
side_effect=lambda x: np.frombuffer(
23+
base64.b64decode(x.encode('utf-8')), dtype=np.uint8).reshape(
24+
(32, 32, 3))):
2425
check_iou(mask_pair)
2526

2627

0 commit comments

Comments
 (0)