Skip to content

Commit 0b1945e

Browse files
committed
address PR comments 2
1 parent 37f1a92 commit 0b1945e

File tree

4 files changed

+22
-21
lines changed

4 files changed

+22
-21
lines changed

nucleus/model_run.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -126,25 +126,13 @@ def loc(self, dataset_item_id: str):
126126
def _format_prediction_response(
127127
self, response: dict
128128
) -> Union[dict, List[Union[BoxPrediction, PolygonPrediction]]]:
129-
annotations = response.get(ANNOTATIONS_KEY, None)
130-
if annotations:
129+
annotation_payload = response.get(ANNOTATIONS_KEY, None)
130+
if annotation_payload:
131131
annotation_response = {}
132-
if BOX_TYPE in annotations:
133-
annotation_response[BOX_TYPE] = [
134-
BoxPrediction.from_json(ann)
135-
for ann in annotations[BOX_TYPE]
136-
]
137-
if POLYGON_TYPE in annotations:
138-
annotation_response[POLYGON_TYPE] = [
139-
PolygonPrediction.from_json(ann)
140-
for ann in annotations[POLYGON_TYPE]
141-
]
142-
if SEGMENTATION_TYPE in annotations:
143-
annotation_response[
144-
SEGMENTATION_TYPE
145-
] = SegmentationPrediction.from_json(
146-
annotations[SEGMENTATION_TYPE]
147-
)
132+
for (type_key, type_cls) in zip([BOX_TYPE, POLYGON_TYPE, SEGMENTATION_TYPE],
133+
[BoxPrediction, PolygonPrediction, SegmentationPrediction]):
134+
if type_key in annotation_payload:
135+
annotation_response[type_key] = [type_cls.from_json(ann) for ann in annotation_payload[type_key]]
148136
return annotation_response
149137
else: # An error occurred
150138
return response

tests/helpers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ def assert_segmentation_annotation_matches_dict(
156156
annotation_dict["annotations"]
157157
)
158158

159+
for instance_segment, dict_segment in zip(sorted(annotation_instance.annotations, key = lambda i: i.index), sorted(annotation_dict["annotations"], key = lambda i: i["index"])):
160+
assert instance_segment.index == dict_segment["index"]
161+
assert instance_segment.label == dict_segment["label"]
162+
159163

160164
# Asserts that a box prediction instance matches a dict representing its properties.
161165
# Useful to check prediction uploads/updates match.

tests/test_annotation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def test_single_semseg_gt_upload(dataset):
8686

8787
response_annotation = dataset.refloc(annotation.reference_id)[
8888
"annotations"
89-
]["segmentation"]
89+
]["segmentation"][0]
9090
assert_segmentation_annotation_matches_dict(
9191
response_annotation, TEST_SEGMENTATION_ANNOTATIONS[0]
9292
)

tests/test_prediction.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,10 @@ def test_segmentation_pred_upload(model_run):
9292
assert response["predictions_ignored"] == 0
9393

9494
response = model_run.refloc(prediction.reference_id)["segmentation"]
95-
assert isinstance(response, SegmentationPrediction)
95+
assert isinstance(response[0], SegmentationPrediction)
9696

9797
assert_segmentation_annotation_matches_dict(
98-
response, TEST_SEGMENTATION_PREDICTIONS[0]
98+
response[0], TEST_SEGMENTATION_PREDICTIONS[0]
9999
)
100100

101101

@@ -238,3 +238,12 @@ def test_mixed_pred_upload(model_run):
238238
assert response["predictions_ignored"] == 0
239239

240240
response_refloc = model_run.refloc(prediction_polygon.reference_id)
241+
assert_box_prediction_matches_dict(
242+
response_refloc["box"][0], TEST_BOX_PREDICTIONS[0]
243+
)
244+
assert_polygon_prediction_matches_dict(
245+
response_refloc["polygon"][0], TEST_POLYGON_PREDICTIONS[0]
246+
)
247+
assert_segmentation_annotation_matches_dict(
248+
response_refloc["segmentation"][0], TEST_SEGMENTATION_PREDICTIONS[0]
249+
)

0 commit comments

Comments
 (0)