Skip to content

Commit 8273daf

Browse files
author
Matt Sokoloff
committed
add mask test back
1 parent 0b1e5c3 commit 8273daf

File tree

2 files changed

+18
-16
lines changed

2 files changed

+18
-16
lines changed

labelbox/data/metrics/iou.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from labelbox.data import annotation_types
1313

1414

15-
def mask_miou(predictions: List[Mask],
16-
ground_truths: List[Mask],
15+
def mask_miou(predictions: List[ObjectAnnotation],
16+
ground_truths: List[ObjectAnnotation],
1717
resize_height=None,
1818
resize_width=None) -> float:
1919
"""
@@ -27,13 +27,15 @@ def mask_miou(predictions: List[Mask],
2727
Returns:
2828
float indicating iou score
2929
"""
30+
#TODO: Filter out non-masks object annotations maybe..
31+
3032
prediction_np = np.max([
31-
pred.raster(binary=True, height=resize_height, width=resize_width)
33+
pred.value.raster(binary=True, height=resize_height, width=resize_width)
3234
for pred in predictions
3335
],
3436
axis=0)
3537
ground_truth_np = np.max([
36-
ground_truth.raster(
38+
ground_truth.value.raster(
3739
binary=True, height=resize_height, width=resize_width)
3840
for ground_truth in ground_truths
3941
],

tests/data/metrics/test_iou.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,20 @@
66

77
from labelbox.data.metrics.iou import data_row_miou
88
from labelbox.data.serialization import NDJsonConverter, LBV1Converter
9-
from labelbox.data.annotation_types import Label, RasterData
9+
from labelbox.data.annotation_types import Label, RasterData, Mask
1010

1111

12-
def check_iou(pair):
12+
def check_iou(pair, mask=False):
1313
default = Label(data=RasterData(uid="ckppihxc10005aeyjen11h7jh"))
14-
assert math.isclose(
15-
data_row_miou(
16-
next(LBV1Converter.deserialize([pair.labels])),
17-
next(NDJsonConverter.deserialize(pair.predictions), default)),
18-
pair.expected)
14+
prediction = next(NDJsonConverter.deserialize(pair.predictions), default)
15+
label = next(LBV1Converter.deserialize([pair.labels]))
16+
if mask:
17+
for annotation in [*prediction.annotations, *label.annotations]:
18+
if isinstance(annotation.value, Mask):
19+
annotation.value.mask.arr = np.frombuffer(
20+
base64.b64decode(annotation.value.mask.url.encode('utf-8')),
21+
dtype=np.uint8).reshape((32, 32, 3))
22+
assert math.isclose(data_row_miou(label, prediction), pair.expected)
1923

2024

2125
def strings_to_fixtures(strings):
@@ -25,11 +29,7 @@ def strings_to_fixtures(strings):
2529
def test_overlapping(polygon_pair, box_pair, mask_pair):
2630
check_iou(polygon_pair)
2731
check_iou(box_pair)
28-
#with patch('labelbox.data.metrics.iou.url_to_numpy',
29-
# side_effect=lambda x: np.frombuffer(
30-
# base64.b64decode(x.encode('utf-8')), dtype=np.uint8).reshape(
31-
# (32, 32, 3))):
32-
# #check_iou(mask_pair)
32+
check_iou(mask_pair, True)
3333

3434

3535
@parametrize("pair",

0 commit comments

Comments
 (0)