Commit 6399ccd

Fix Nucleus export_predictions_generator not returning confidences (#390)

1 parent 0791384 · commit 6399ccd
9 files changed: +70 additions, −27 deletions

CHANGELOG.md

Lines changed: 5 additions & 0 deletions

@@ -5,6 +5,11 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.15.10](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.15.10) - 2023-07-20
+
+### Added
+- Fix `slice.export_predictions(args)` and `slice.export_predictions_generator(args)` methods to return `Predictions` instead of `Annotations`
+
 ## [0.15.9](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.15.9) - 2023-06-26
 
 ### Added
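
For context, a brief usage sketch of what this changelog entry means for callers (not part of the commit): after 0.15.10, the slice export helpers yield `Prediction` objects such as `BoxPrediction`, so model confidences come back again. The API key, slice ID, the `model` argument, and the per-type row layout below are assumptions for illustration, not taken from this diff.

```python
# Hedged usage sketch, not from this commit. The client setup, slice ID, model
# argument, and the "box"-keyed row layout are assumptions for illustration.
import nucleus
from nucleus import BoxPrediction

client = nucleus.NucleusClient("YOUR_API_KEY")        # hypothetical API key
slc = client.get_slice("slc_1234567890abcdef")        # hypothetical slice ID
model = client.models[0]                              # any model with predictions on the slice

for row in slc.export_predictions_generator(model):   # model argument assumed
    for box in row.get("box", []):                    # per-type keys assumed, mirroring convert_export_payload
        assert isinstance(box, BoxPrediction)         # before this fix these were BoxAnnotation objects
        print(box.label, box.confidence)              # confidence is populated again
```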

nucleus/utils.py

Lines changed: 49 additions & 15 deletions

@@ -221,6 +221,7 @@ def format_scale_task_info_response(response: dict) -> Union[Dict, List[Dict]]:
     return ret
 
 
+# pylint: disable=too-many-branches
 def convert_export_payload(api_payload, has_predictions: bool = False):
     """Helper function to convert raw JSON to API objects
 
@@ -239,33 +240,66 @@ def convert_export_payload(api_payload, has_predictions: bool = False):
         if row.get(SEGMENTATION_TYPE) is not None:
             segmentation = row[SEGMENTATION_TYPE]
             segmentation[REFERENCE_ID_KEY] = row[ITEM_KEY][REFERENCE_ID_KEY]
-            annotations[SEGMENTATION_TYPE] = SegmentationAnnotation.from_json(
-                segmentation
-            )
+            if not has_predictions:
+                annotations[
+                    SEGMENTATION_TYPE
+                ] = SegmentationAnnotation.from_json(segmentation)
+            else:
+                annotations[
+                    SEGMENTATION_TYPE
+                ] = SegmentationPrediction.from_json(segmentation)
         for polygon in row[POLYGON_TYPE]:
             polygon[REFERENCE_ID_KEY] = row[ITEM_KEY][REFERENCE_ID_KEY]
-            annotations[POLYGON_TYPE].append(
-                PolygonAnnotation.from_json(polygon)
-            )
+            if not has_predictions:
+                annotations[POLYGON_TYPE].append(
+                    PolygonAnnotation.from_json(polygon)
+                )
+            else:
+                annotations[POLYGON_TYPE].append(
+                    PolygonPrediction.from_json(polygon)
+                )
         for line in row[LINE_TYPE]:
             line[REFERENCE_ID_KEY] = row[ITEM_KEY][REFERENCE_ID_KEY]
-            annotations[LINE_TYPE].append(LineAnnotation.from_json(line))
+            if not has_predictions:
+                annotations[LINE_TYPE].append(LineAnnotation.from_json(line))
+            else:
+                annotations[LINE_TYPE].append(LinePrediction.from_json(line))
         for keypoints in row[KEYPOINTS_TYPE]:
             keypoints[REFERENCE_ID_KEY] = row[ITEM_KEY][REFERENCE_ID_KEY]
-            annotations[KEYPOINTS_TYPE].append(
-                KeypointsAnnotation.from_json(keypoints)
-            )
+            if not has_predictions:
+                annotations[KEYPOINTS_TYPE].append(
+                    KeypointsAnnotation.from_json(keypoints)
+                )
+            else:
+                annotations[KEYPOINTS_TYPE].append(
+                    KeypointsPrediction.from_json(keypoints)
+                )
         for box in row[BOX_TYPE]:
             box[REFERENCE_ID_KEY] = row[ITEM_KEY][REFERENCE_ID_KEY]
-            annotations[BOX_TYPE].append(BoxAnnotation.from_json(box))
+            if not has_predictions:
+                annotations[BOX_TYPE].append(BoxAnnotation.from_json(box))
+            else:
+                annotations[BOX_TYPE].append(BoxPrediction.from_json(box))
         for cuboid in row[CUBOID_TYPE]:
             cuboid[REFERENCE_ID_KEY] = row[ITEM_KEY][REFERENCE_ID_KEY]
-            annotations[CUBOID_TYPE].append(CuboidAnnotation.from_json(cuboid))
+            if not has_predictions:
+                annotations[CUBOID_TYPE].append(
+                    CuboidAnnotation.from_json(cuboid)
+                )
+            else:
+                annotations[CUBOID_TYPE].append(
+                    CuboidPrediction.from_json(cuboid)
+                )
         for category in row[CATEGORY_TYPE]:
             category[REFERENCE_ID_KEY] = row[ITEM_KEY][REFERENCE_ID_KEY]
-            annotations[CATEGORY_TYPE].append(
-                CategoryAnnotation.from_json(category)
-            )
+            if not has_predictions:
+                annotations[CATEGORY_TYPE].append(
+                    CategoryAnnotation.from_json(category)
+                )
+            else:
+                annotations[CATEGORY_TYPE].append(
+                    CategoryPrediction.from_json(category)
+                )
         for multicategory in row[MULTICATEGORY_TYPE]:
             multicategory[REFERENCE_ID_KEY] = row[ITEM_KEY][REFERENCE_ID_KEY]
             annotations[MULTICATEGORY_TYPE].append(
pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@ exclude = '''
 
 [tool.poetry]
 name = "scale-nucleus"
-version = "0.15.9"
+version = "0.15.10"
 description = "The official Python client library for Nucleus, the Data Platform for AI"
 license = "MIT"
 authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]

tests/test_autocurate.py

Lines changed: 2 additions & 1 deletion

@@ -43,7 +43,8 @@ def model_run(CLIENT):
     yield run
 
     response = CLIENT.delete_model(model.id)
-    assert response == {}
+    assert "msg" in response
+    assert response["msg"] == "Model deletion running in an async job"
 
 
 # @pytest.mark.integration

tests/test_dataset.py

Lines changed: 0 additions & 6 deletions

@@ -589,9 +589,6 @@ def test_dataset_get_image_indexing_status(CLIENT):
     assert resp["embedding_count"] == 170
     assert resp["image_count"] == 170
     assert "object_count" not in resp
-    assert round(resp["percent_indexed"], 2) == round(
-        resp["image_count"] / resp["embedding_count"], 2
-    )
 
 
 @pytest.mark.integration
@@ -601,9 +598,6 @@ def test_dataset_get_object_indexing_status(CLIENT):
     assert resp["embedding_count"] == 422
     assert resp["object_count"] == 423
     assert "image_count" not in resp
-    assert round(resp["percent_indexed"], 2) == round(
-        resp["object_count"] / resp["embedding_count"], 2
-    )
 
 
 @pytest.mark.integration

tests/test_models.py

Lines changed: 3 additions & 0 deletions

@@ -78,6 +78,9 @@ def test_model_creation_and_listing(CLIENT, dataset):
 
     # Delete the model
     CLIENT.delete_model(model.id)
+    time.sleep(
+        30
+    )  # model deletion runs async TODO: do a correct job await here instead of sleep
     ms = CLIENT.models
 
     assert model not in ms
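
The 30-second sleep added above is a stopgap for the async deletion that the TODO calls out. A hedged alternative sketch (not part of this commit) that polls `CLIENT.models`, which the test already reads, until the model disappears or a timeout elapses:

```python
# Hedged sketch, not from this commit: poll the model listing instead of a fixed
# 30-second sleep. Only client.models and model.id, both used in the test above,
# are assumed to exist.
import time


def wait_for_model_deletion(client, model, timeout_s=60.0, poll_s=2.0):
    """Block until `model` is no longer listed by the client, or raise on timeout."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if model not in client.models:
            return
        time.sleep(poll_s)
    raise TimeoutError(f"Model {model.id} still listed after {timeout_s:.0f}s")


# Usage in the test, replacing time.sleep(30):
#     wait_for_model_deletion(CLIENT, model)
```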

tests/test_prediction.py

Lines changed: 4 additions & 2 deletions

@@ -111,7 +111,8 @@ def model_run(CLIENT):
     yield run
 
     response = CLIENT.delete_model(model.id)
-    assert response == {}
+    assert "msg" in response
+    assert response["msg"] == "Model deletion running in an async job"
 
 
 @pytest.fixture()
@@ -139,7 +140,8 @@ def scene_category_model_run(CLIENT):
     yield run
 
     response = CLIENT.delete_model(model.id)
-    assert response == {}
+    assert "msg" in response
+    assert response["msg"] == "Model deletion running in an async job"
 
 
 def test_box_pred_upload(model_run):

tests/test_track.py

Lines changed: 3 additions & 1 deletion

@@ -92,7 +92,9 @@ def test_create_mp_with_tracks(CLIENT, dataset_scene):
     ).issubset(expected_track_reference_ids)
 
     # Cleanup
-    assert CLIENT.delete_model(model.id) == {}
+    response = CLIENT.delete_model(model.id)
+    assert "msg" in response
+    assert response["msg"] == "Model deletion running in an async job"
     dataset_scene.delete_tracks(expected_track_reference_ids)
 
 
tests/validate/test_scenario_test.py

Lines changed: 3 additions & 1 deletion

@@ -189,7 +189,9 @@ def test_scenario_test_get_tracks(
 
     # Clean
     CLIENT.validate.delete_scenario_test(scenario_test.id)
-    assert CLIENT.delete_model(model.id) == {}
+    response = CLIENT.delete_model(model.id)
+    assert "msg" in response
+    assert response["msg"] == "Model deletion running in an async job"
 
 
 def test_no_criteria_raises_error(CLIENT, test_slice, annotations):
