Skip to content

Commit 40badb6

Browse files
committed
incorporate additional tests
1 parent b554529 commit 40badb6

File tree

8 files changed

+119
-57
lines changed

8 files changed

+119
-57
lines changed

nucleus/annotation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
ANNOTATIONS_KEY,
2222
)
2323

24+
ORIGINAL_URL_KEY = "original_url"
25+
2426

2527
class Segment:
2628
def __init__(
@@ -74,7 +76,7 @@ def __str__(self):
7476
@classmethod
7577
def from_json(cls, payload: dict):
7678
return cls(
77-
mask_url=payload[MASK_URL_KEY],
79+
mask_url=payload.get(MASK_URL_KEY, None),
7880
annotations=[
7981
Segment.from_json(ann)
8082
for ann in payload.get(ANNOTATIONS_KEY, [])

nucleus/constants.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
# NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"
2-
NUCLEUS_ENDPOINT = "http://localhost:3000/v1/nucleus"
1+
NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"
32
ITEMS_KEY = "items"
43
ITEM_KEY = "item"
54
REFERENCE_ID_KEY = "reference_id"

nucleus/dataset.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -253,16 +253,11 @@ def _format_dataset_item_response(self, response: dict) -> dict:
253253
for ann in annotation_payload[POLYGON_TYPE]
254254
]
255255
if SEGMENTATION_TYPE in annotation_payload:
256-
annotation_response[SEGMENTATION_TYPE] = [
257-
SegmentationAnnotation.from_json(ann)
258-
for ann in annotation_payload[SEGMENTATION_TYPE]
259-
]
260-
# annotations = [
261-
# BoxAnnotation.from_json(ann)
262-
# if ann["type"] == "box"
263-
# else PolygonAnnotation.from_json(ann)
264-
# for ann in annotation_payload
265-
# ]
256+
annotation_response[
257+
SEGMENTATION_TYPE
258+
] = SegmentationAnnotation.from_json(
259+
annotation_payload[SEGMENTATION_TYPE]
260+
)
266261
return {
267262
ITEM_KEY: DatasetItem.from_json(item),
268263
ANNOTATIONS_KEY: annotation_response,

nucleus/model_run.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
from typing import Optional, List, Dict, Any, Union
2-
from .constants import ANNOTATIONS_KEY, DEFAULT_ANNOTATION_UPDATE_MODE
2+
from .constants import (
3+
ANNOTATIONS_KEY,
4+
DEFAULT_ANNOTATION_UPDATE_MODE,
5+
BOX_TYPE,
6+
POLYGON_TYPE,
7+
SEGMENTATION_TYPE,
8+
)
39
from .prediction import (
410
BoxPrediction,
511
PolygonPrediction,
@@ -122,12 +128,26 @@ def _format_prediction_response(
122128
self, response: dict
123129
) -> Union[dict, List[Union[BoxPrediction, PolygonPrediction]]]:
124130
annotations = response.get(ANNOTATIONS_KEY, None)
131+
print(annotations)
125132
if annotations:
126-
return [
127-
BoxPrediction.from_json(ann)
128-
if ann["type"] == "box"
129-
else PolygonPrediction.from_json(ann)
130-
for ann in annotations
131-
]
133+
annotation_response = {}
134+
if BOX_TYPE in annotations:
135+
annotation_response[BOX_TYPE] = [
136+
BoxPrediction.from_json(ann)
137+
for ann in annotations[BOX_TYPE]
138+
]
139+
if POLYGON_TYPE in annotations:
140+
annotation_response[POLYGON_TYPE] = [
141+
PolygonPrediction.from_json(ann)
142+
for ann in annotations[POLYGON_TYPE]
143+
]
144+
if SEGMENTATION_TYPE in annotations:
145+
annotation_response[
146+
SEGMENTATION_TYPE
147+
] = SegmentationPrediction.from_json(
148+
annotations[SEGMENTATION_TYPE]
149+
)
150+
print(annotation_response)
151+
return annotation_response
132152
else: # An error occurred
133153
return response

tests/helpers.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
TEST_DATASET_NAME = "[PyTest] Test Dataset"
1212
TEST_SLICE_NAME = "[PyTest] Test Slice"
1313

14-
TEST_MODEL_NAME = '[PyTest] Test Model Name'
15-
TEST_MODEL_REFERENCE = '[PyTest] Test Model Reference'
16-
TEST_MODEL_RUN = '[PyTest] Test Model Run Reference'
17-
TEST_DATASET_NAME = '[PyTest] Test Dataset'
18-
TEST_SLICE_NAME = '[PyTest] Test Slice'
14+
TEST_MODEL_NAME = "[PyTest] Test Model Name"
15+
TEST_MODEL_REFERENCE = "[PyTest] Test Model Reference"
16+
TEST_MODEL_RUN = "[PyTest] Test Model Run Reference"
17+
TEST_DATASET_NAME = "[PyTest] Test Dataset"
18+
TEST_SLICE_NAME = "[PyTest] Test Slice"
1919
TEST_IMG_URLS = [
2020
"s3://scaleapi-attachments/BDD/BDD/bdd100k/images/100k/train/6dd63871-831611a6.jpg",
2121
"s3://scaleapi-attachments/BDD/BDD/bdd100k/images/100k/train/82c1005c-e2d1d94f.jpg",
@@ -24,16 +24,16 @@
2424
"s3://scaleapi-attachments/BDD/BDD/bdd100k/images/100k/train/89b42832-10d662f4.jpg",
2525
]
2626
TEST_DATASET_ITEMS = [
27-
DatasetItem(TEST_IMG_URLS[0], '1'),
28-
DatasetItem(TEST_IMG_URLS[1], '2'),
29-
DatasetItem(TEST_IMG_URLS[2], '3'),
30-
DatasetItem(TEST_IMG_URLS[3], '4')
27+
DatasetItem(TEST_IMG_URLS[0], "1"),
28+
DatasetItem(TEST_IMG_URLS[1], "2"),
29+
DatasetItem(TEST_IMG_URLS[2], "3"),
30+
DatasetItem(TEST_IMG_URLS[3], "4"),
3131
]
3232
TEST_PREDS = [
33-
BoxPrediction('[Pytest Box Prediction 1]', 0, 0, 100, 100, '1'),
34-
BoxPrediction('[Pytest Box Prediction 2]', 0, 0, 100, 100, '2'),
35-
BoxPrediction('[Pytest Box Prediction 3]', 0, 0, 100, 100, '3'),
36-
BoxPrediction('[Pytest Box Prediction 4]', 0, 0, 100, 100, '4')
33+
BoxPrediction("[Pytest Box Prediction 1]", 0, 0, 100, 100, "1"),
34+
BoxPrediction("[Pytest Box Prediction 2]", 0, 0, 100, 100, "2"),
35+
BoxPrediction("[Pytest Box Prediction 3]", 0, 0, 100, 100, "3"),
36+
BoxPrediction("[Pytest Box Prediction 4]", 0, 0, 100, 100, "4"),
3737
]
3838

3939

@@ -147,6 +147,16 @@ def assert_polygon_annotation_matches_dict(
147147
assert instance_pt["y"] == dict_pt["y"]
148148

149149

150+
def assert_segmentation_annotation_matches_dict(
151+
annotation_instance, annotation_dict
152+
):
153+
assert annotation_instance.mask_url == annotation_dict["mask_url"]
154+
# Cannot guarantee segments are in same order
155+
assert len(annotation_instance.annotations) == len(
156+
annotation_dict["annotations"]
157+
)
158+
159+
150160
# Asserts that a box prediction instance matches a dict representing its properties.
151161
# Useful to check prediction uploads/updates match.
152162
def assert_box_prediction_matches_dict(prediction_instance, prediction_dict):

tests/test_annotation.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from helpers import (
3+
from .helpers import (
44
TEST_DATASET_NAME,
55
TEST_IMG_URLS,
66
TEST_BOX_ANNOTATIONS,
@@ -9,6 +9,7 @@
99
reference_id_from_url,
1010
assert_box_annotation_matches_dict,
1111
assert_polygon_annotation_matches_dict,
12+
assert_segmentation_annotation_matches_dict,
1213
)
1314

1415
from nucleus import (
@@ -48,7 +49,7 @@ def test_box_gt_upload(dataset):
4849
assert response["annotations_processed"] == 1
4950
assert response["annotations_ignored"] == 0
5051

51-
response = dataset.refloc(annotation.reference_id)["annotations"]
52+
response = dataset.refloc(annotation.reference_id)["annotations"]["box"]
5253
assert len(response) == 1
5354
response_annotation = response[0]
5455
assert_box_annotation_matches_dict(
@@ -64,7 +65,9 @@ def test_polygon_gt_upload(dataset):
6465
assert response["annotations_processed"] == 1
6566
assert response["annotations_ignored"] == 0
6667

67-
response = dataset.refloc(annotation.reference_id)["annotations"]
68+
response = dataset.refloc(annotation.reference_id)["annotations"][
69+
"polygon"
70+
]
6871
assert len(response) == 1
6972
response_annotation = response[0]
7073
assert_polygon_annotation_matches_dict(
@@ -80,7 +83,13 @@ def test_single_semseg_gt_upload(dataset):
8083
assert response["dataset_id"] == dataset.id
8184
assert response["annotations_processed"] == 1
8285
assert response["annotations_ignored"] == 0
83-
# assert_box_annotation_matches_dict(response_annotation, TEST_BOX_ANNOTATIONS[0])
86+
87+
response_annotation = dataset.refloc(annotation.reference_id)[
88+
"annotations"
89+
]["segmentation"]
90+
assert_segmentation_annotation_matches_dict(
91+
response_annotation, TEST_SEGMENTATION_ANNOTATIONS[0]
92+
)
8493

8594

8695
def test_batch_semseg_gt_upload(dataset):
@@ -142,6 +151,12 @@ def test_mixed_annotation_upload(dataset):
142151
assert response["dataset_id"] == dataset.id
143152
assert response["annotations_processed"] == 10
144153
assert response["annotations_ignored"] == 0
154+
response_annotations = dataset.refloc(bbox_annotations[0].reference_id)[
155+
"annotations"
156+
]
157+
assert len(response_annotations) == 2
158+
assert len(response_annotations["box"]) == 1
159+
assert "segmentation" in response_annotations
145160

146161

147162
def test_box_gt_upload_update(dataset):
@@ -165,7 +180,7 @@ def test_box_gt_upload_update(dataset):
165180
assert response["annotations_processed"] == 1
166181
assert response["annotations_ignored"] == 0
167182

168-
response = dataset.refloc(annotation.reference_id)["annotations"]
183+
response = dataset.refloc(annotation.reference_id)["annotations"]["box"]
169184
assert len(response) == 1
170185
response_annotation = response[0]
171186
assert_box_annotation_matches_dict(
@@ -194,7 +209,7 @@ def test_box_gt_upload_ignore(dataset):
194209
assert response["annotations_processed"] == 1
195210
assert response["annotations_ignored"] == 1
196211

197-
response = dataset.refloc(annotation.reference_id)["annotations"]
212+
response = dataset.refloc(annotation.reference_id)["annotations"]["box"]
198213
assert len(response) == 1
199214
response_annotation = response[0]
200215
assert_box_annotation_matches_dict(
@@ -223,7 +238,9 @@ def test_polygon_gt_upload_update(dataset):
223238
assert response["annotations_processed"] == 1
224239
assert response["annotations_ignored"] == 0
225240

226-
response = dataset.refloc(annotation.reference_id)["annotations"]
241+
response = dataset.refloc(annotation.reference_id)["annotations"][
242+
"polygon"
243+
]
227244
assert len(response) == 1
228245
response_annotation = response[0]
229246
assert_polygon_annotation_matches_dict(
@@ -253,7 +270,9 @@ def test_polygon_gt_upload_ignore(dataset):
253270
assert response["annotations_processed"] == 1
254271
assert response["annotations_ignored"] == 1
255272

256-
response = dataset.refloc(annotation.reference_id)["annotations"]
273+
response = dataset.refloc(annotation.reference_id)["annotations"][
274+
"polygon"
275+
]
257276
assert len(response) == 1
258277
response_annotation = response[0]
259278
assert_polygon_annotation_matches_dict(

tests/test_dataset.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from helpers import (
3+
from .helpers import (
44
TEST_SLICE_NAME,
55
TEST_DATASET_NAME,
66
TEST_IMG_URLS,
@@ -17,6 +17,7 @@
1717
DATASET_ID_KEY,
1818
)
1919

20+
2021
@pytest.fixture()
2122
def dataset(CLIENT):
2223
ds = CLIENT.create_dataset(TEST_DATASET_NAME)
@@ -25,6 +26,7 @@ def dataset(CLIENT):
2526
response = CLIENT.delete_dataset(ds.id)
2627
assert response == {}
2728

29+
2830
def test_dataset_create_and_delete(CLIENT):
2931
# Creation
3032
ds = CLIENT.create_dataset(TEST_DATASET_NAME)
@@ -88,14 +90,17 @@ def test_dataset_list_autotags(CLIENT, dataset):
8890
autotag_response = CLIENT.list_autotags(dataset.id)
8991
assert autotag_response == []
9092

93+
9194
def test_slice_create_and_delete(dataset):
9295
# Dataset upload
9396
ds_items = []
9497
for url in TEST_IMG_URLS:
95-
ds_items.append(DatasetItem(
96-
image_location=url,
97-
reference_id=reference_id_from_url(url),
98-
))
98+
ds_items.append(
99+
DatasetItem(
100+
image_location=url,
101+
reference_id=reference_id_from_url(url),
102+
)
103+
)
99104
response = dataset.append(ds_items)
100105
assert ERROR_PAYLOAD not in response.json()
101106

@@ -121,10 +126,12 @@ def test_slice_append(dataset):
121126
# Dataset upload
122127
ds_items = []
123128
for url in TEST_IMG_URLS:
124-
ds_items.append(DatasetItem(
125-
image_location=url,
126-
reference_id=reference_id_from_url(url),
127-
))
129+
ds_items.append(
130+
DatasetItem(
131+
image_location=url,
132+
reference_id=reference_id_from_url(url),
133+
)
134+
)
128135
response = dataset.append(ds_items)
129136
assert ERROR_PAYLOAD not in response.json()
130137

0 commit comments

Comments
 (0)