Skip to content

Commit 91b8743

Browse files
authored
add method for exporting predictions (#300)
* add export predictions * cleanup * finish test coverage * address PR comments * bump version number + describe changes in changelog
1 parent 879a2db commit 91b8743

File tree

8 files changed

+114
-6
lines changed

8 files changed

+114
-6
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
9+
## [0.11.1](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.11.1) - 2022-05-19
10+
11+
### Added
12+
13+
- Exporting model predictions from a slice
14+
815
## [0.11.0](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.11.0) - 2022-05-13
916

1017
### Added

nucleus/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,6 @@ def make_request(
837837
Returns:
838838
Response payload as JSON dict.
839839
"""
840-
print(payload, route)
841840
if payload is None:
842841
payload = {}
843842
if requests_command is requests.get:

nucleus/annotation.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,19 @@ def to_payload(self) -> dict:
180180
EMBEDDING_VECTOR_KEY: self.embedding_vector,
181181
}
182182

183+
def __eq__(self, other):
184+
return (
185+
self.label == other.label
186+
and self.x == other.x
187+
and self.y == other.y
188+
and self.width == other.width
189+
and self.height == other.height
190+
and self.reference_id == other.reference_id
191+
and self.annotation_id == other.annotation_id
192+
and sorted(self.metadata.items()) == sorted(other.metadata.items())
193+
and self.embedding_vector == other.embedding_vector
194+
)
195+
183196

184197
@dataclass
185198
class Point:

nucleus/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
POINTCLOUD_URL_KEY = "pointcloud_url"
106106
POSITION_KEY = "position"
107107
PREDICTIONS_IGNORED_KEY = "predictions_ignored"
108+
PREDICTIONS_KEY = "predictions"
108109
PREDICTIONS_PROCESSED_KEY = "predictions_processed"
109110
REFERENCE_IDS_KEY = "reference_ids"
110111
REFERENCE_ID_KEY = "reference_id"

nucleus/slice.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,37 @@ def items_and_annotations(
234234
)
235235
return convert_export_payload(api_payload[EXPORTED_ROWS])
236236

237+
def export_predictions(
238+
self, model
239+
) -> List[Dict[str, Union[DatasetItem, Dict[str, List[Annotation]]]]]:
240+
"""Provides a list of all DatasetItems and Predictions in the Slice for the given Model.
241+
242+
Parameters:
243+
model (Model): the nucleus model objects representing the model for which to export predictions.
244+
245+
Returns:
246+
List where each element is a dict containing the DatasetItem
247+
and all of its associated Predictions, grouped by type (e.g. box).
248+
::
249+
250+
List[{
251+
"item": DatasetItem,
252+
"predicions": {
253+
"box": List[BoxAnnotation],
254+
"polygon": List[PolygonAnnotation],
255+
"cuboid": List[CuboidAnnotation],
256+
"segmentation": List[SegmentationAnnotation],
257+
"category": List[CategoryAnnotation],
258+
}
259+
}]
260+
"""
261+
api_payload = self._client.make_request(
262+
payload=None,
263+
route=f"slice/{self.id}/{model.id}/exportForTraining",
264+
requests_command=requests.get,
265+
)
266+
return convert_export_payload(api_payload[EXPORTED_ROWS], True)
267+
237268
def send_to_labeling(self, project_id: str):
238269
"""Send items in the Slice as tasks to a Scale labeling project.
239270

nucleus/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
PAGE_SIZE,
3838
PAGE_TOKEN,
3939
POLYGON_TYPE,
40+
PREDICTIONS_KEY,
4041
REFERENCE_ID_KEY,
4142
SEGMENTATION_TYPE,
4243
)
@@ -187,7 +188,7 @@ def format_dataset_item_response(response: dict) -> dict:
187188
}
188189

189190

190-
def convert_export_payload(api_payload):
191+
def convert_export_payload(api_payload, has_predictions: bool = False):
191192
"""Helper function to convert raw JSON to API objects
192193
193194
Args:
@@ -237,7 +238,9 @@ def convert_export_payload(api_payload):
237238
annotations[MULTICATEGORY_TYPE].append(
238239
MultiCategoryAnnotation.from_json(multicategory)
239240
)
240-
return_payload_row[ANNOTATIONS_KEY] = annotations
241+
return_payload_row[
242+
ANNOTATIONS_KEY if not has_predictions else PREDICTIONS_KEY
243+
] = annotations
241244
return_payload.append(return_payload_row)
242245
return return_payload
243246

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ exclude = '''
2121

2222
[tool.poetry]
2323
name = "scale-nucleus"
24-
version = "0.11.0"
24+
version = "0.11.1"
2525
description = "The official Python client library for Nucleus, the Data Platform for AI"
2626
license = "MIT"
2727
authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]

tests/test_slice.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,38 @@
33
import pytest
44
import requests
55

6-
from nucleus import BoxAnnotation, Dataset, NucleusClient, Slice
7-
from nucleus.constants import ANNOTATIONS_KEY, BOX_TYPE, ITEM_KEY
6+
from nucleus import BoxAnnotation, BoxPrediction, Dataset, NucleusClient, Slice
7+
from nucleus.constants import (
8+
ANNOTATIONS_KEY,
9+
BOX_TYPE,
10+
ITEM_KEY,
11+
PREDICTIONS_KEY,
12+
)
813
from nucleus.job import AsyncJob
914

1015
from .helpers import (
1116
TEST_BOX_ANNOTATIONS,
17+
TEST_BOX_PREDICTIONS,
1218
TEST_PROJECT_ID,
1319
TEST_SLICE_NAME,
1420
get_uuid,
1521
)
1622

1723

24+
@pytest.fixture()
25+
def slc(CLIENT, dataset):
26+
slice_ref_ids = [item.reference_id for item in dataset.items[:1]]
27+
# Slice creation
28+
slc = dataset.create_slice(
29+
name=TEST_SLICE_NAME,
30+
reference_ids=slice_ref_ids,
31+
)
32+
33+
yield slc
34+
35+
CLIENT.delete_slice(slc.id)
36+
37+
1838
def test_reprs():
1939
# Have to define here in order to have access to all relevant objects
2040
def test_repr(test_object: any):
@@ -89,6 +109,40 @@ def get_expected_item(reference_id):
89109
] == get_expected_box_annotation(reference_id)
90110

91111

112+
def test_slice_create_and_prediction_export(dataset, slc, model):
113+
# Dataset upload
114+
ds_items = dataset.items
115+
116+
predictions = [
117+
BoxPrediction(**pred_raw) for pred_raw in TEST_BOX_PREDICTIONS
118+
]
119+
response = dataset.upload_predictions(model, predictions)
120+
121+
assert response
122+
123+
slice_reference_ids = [item.reference_id for item in slc.items]
124+
125+
def get_expected_box_prediction(reference_id):
126+
for prediction in predictions:
127+
if prediction.reference_id == reference_id:
128+
return prediction
129+
130+
def get_expected_item(reference_id):
131+
if reference_id not in slice_reference_ids:
132+
raise ValueError("Got results outside the slice")
133+
for item in ds_items:
134+
if item.reference_id == reference_id:
135+
return item
136+
137+
exported = slc.export_predictions(model)
138+
for row in exported:
139+
reference_id = row[ITEM_KEY].reference_id
140+
assert row[ITEM_KEY] == get_expected_item(reference_id)
141+
assert row[PREDICTIONS_KEY][BOX_TYPE][
142+
0
143+
] == get_expected_box_prediction(reference_id)
144+
145+
92146
def test_slice_append(dataset):
93147
ds_items = dataset.items
94148

0 commit comments

Comments
 (0)