6
6
7
7
from labelbox .data .metrics .iou import data_row_miou
8
8
from labelbox .data .serialization import NDJsonConverter , LBV1Converter
9
- from labelbox .data .annotation_types import Label , RasterData
9
+ from labelbox .data .annotation_types import Label , RasterData , Mask
10
10
11
11
12
- def check_iou (pair ):
12
+ def check_iou (pair , mask = False ):
13
13
default = Label (data = RasterData (uid = "ckppihxc10005aeyjen11h7jh" ))
14
- assert math .isclose (
15
- data_row_miou (
16
- next (LBV1Converter .deserialize ([pair .labels ])),
17
- next (NDJsonConverter .deserialize (pair .predictions ), default )),
18
- pair .expected )
14
+ prediction = next (NDJsonConverter .deserialize (pair .predictions ), default )
15
+ label = next (LBV1Converter .deserialize ([pair .labels ]))
16
+ if mask :
17
+ for annotation in [* prediction .annotations , * label .annotations ]:
18
+ if isinstance (annotation .value , Mask ):
19
+ annotation .value .mask .arr = np .frombuffer (
20
+ base64 .b64decode (annotation .value .mask .url .encode ('utf-8' )),
21
+ dtype = np .uint8 ).reshape ((32 , 32 , 3 ))
22
+ assert math .isclose (data_row_miou (label , prediction ), pair .expected )
19
23
20
24
21
25
def strings_to_fixtures (strings ):
@@ -25,11 +29,7 @@ def strings_to_fixtures(strings):
25
29
def test_overlapping (polygon_pair , box_pair , mask_pair ):
26
30
check_iou (polygon_pair )
27
31
check_iou (box_pair )
28
- #with patch('labelbox.data.metrics.iou.url_to_numpy',
29
- # side_effect=lambda x: np.frombuffer(
30
- # base64.b64decode(x.encode('utf-8')), dtype=np.uint8).reshape(
31
- # (32, 32, 3))):
32
- # #check_iou(mask_pair)
32
+ check_iou (mask_pair , True )
33
33
34
34
35
35
@parametrize ("pair" ,
0 commit comments