Skip to content

Commit 0ddc3f5

Browse files
authored
paginate export methods (#341)
* paginate * docstring * bump semver and changelog * lint
1 parent e18f313 commit 0ddc3f5

File tree

6 files changed

+65
-14
lines changed

6 files changed

+65
-14
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [0.14.12](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.12) - 2022-08-05
9+
10+
### Added
11+
- Added auto-paginated `Slice.export_predictions_generator`
12+
### Fixed
13+
- Change `{Dataset,Slice}.items_and_annotation_generator` to work with improved paginate endpoint
14+
815
## [0.14.11](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.11) - 2022-07-20
916

1017
### Fixed

nucleus/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
AUTOTAG_SCORE_THRESHOLD = "score_threshold"
2929
EXPORTED_ROWS = "exportedRows"
3030
EXPORTED_SCALE_TASK_INFO_ROWS = "exportedScaleTaskInfoRows"
31+
EXPORT_FOR_TRAINING_KEY = "data"
3132
CAMERA_MODEL_KEY = "camera_model"
3233
CAMERA_PARAMS_KEY = "camera_params"
3334
CLASS_PDF_KEY = "class_pdf"

nucleus/dataset.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
DEFAULT_ANNOTATION_UPDATE_MODE,
3434
EMBEDDING_DIMENSION_KEY,
3535
EMBEDDINGS_URL_KEY,
36+
EXPORT_FOR_TRAINING_KEY,
3637
EXPORTED_ROWS,
3738
FRAME_RATE_KEY,
3839
ITEMS_KEY,
@@ -1250,8 +1251,15 @@ def items_and_annotation_generator(
12501251
}
12511252
}]
12521253
"""
1253-
for item in self.items_generator():
1254-
yield self.refloc(reference_id=item.reference_id)
1254+
json_generator = paginate_generator(
1255+
client=self._client,
1256+
endpoint=f"dataset/{self.id}/exportForTrainingPage",
1257+
result_key=EXPORT_FOR_TRAINING_KEY,
1258+
page_size=100000,
1259+
)
1260+
for data in json_generator:
1261+
for ia in convert_export_payload([data], has_predictions=False):
1262+
yield ia
12551263

12561264
def export_embeddings(
12571265
self,

nucleus/slice.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44
import requests
55

66
from nucleus.annotation import Annotation
7-
from nucleus.constants import EXPORTED_ROWS, ITEMS_KEY
7+
from nucleus.constants import EXPORT_FOR_TRAINING_KEY, EXPORTED_ROWS, ITEMS_KEY
88
from nucleus.dataset_item import DatasetItem
99
from nucleus.errors import NucleusAPIError
1010
from nucleus.job import AsyncJob
1111
from nucleus.utils import (
1212
KeyErrorDict,
1313
convert_export_payload,
14-
format_dataset_item_response,
1514
format_scale_task_info_response,
1615
paginate_generator,
1716
)
@@ -203,13 +202,15 @@ def items_and_annotation_generator(
203202
}
204203
}]
205204
"""
206-
for item in self.items_generator():
207-
yield format_dataset_item_response(
208-
self._client.dataitem_ref_id(
209-
dataset_id=self.dataset_id,
210-
reference_id=item.reference_id,
211-
)
212-
)
205+
json_generator = paginate_generator(
206+
client=self._client,
207+
endpoint=f"slice/{self.id}/exportForTrainingPage",
208+
result_key=EXPORT_FOR_TRAINING_KEY,
209+
page_size=100000,
210+
)
211+
for data in json_generator:
212+
for ia in convert_export_payload([data], has_predictions=False):
213+
yield ia
213214

214215
def items_and_annotations(
215216
self,
@@ -256,7 +257,7 @@ def export_predictions(
256257
257258
List[{
258259
"item": DatasetItem,
259-
"predicions": {
260+
"predictions": {
260261
"box": List[BoxAnnotation],
261262
"polygon": List[PolygonAnnotation],
262263
"cuboid": List[CuboidAnnotation],
@@ -272,6 +273,40 @@ def export_predictions(
272273
)
273274
return convert_export_payload(api_payload[EXPORTED_ROWS], True)
274275

276+
def export_predictions_generator(
277+
self, model
278+
) -> Iterable[Dict[str, Union[DatasetItem, Dict[str, List[Annotation]]]]]:
279+
"""Provides a list of all DatasetItems and Predictions in the Slice for the given Model.
280+
281+
Parameters:
282+
model (Model): the nucleus model objects representing the model for which to export predictions.
283+
284+
Returns:
285+
Iterable where each element is a dict containing the DatasetItem
286+
and all of its associated Predictions, grouped by type (e.g. box).
287+
::
288+
289+
List[{
290+
"item": DatasetItem,
291+
"predictions": {
292+
"box": List[BoxAnnotation],
293+
"polygon": List[PolygonAnnotation],
294+
"cuboid": List[CuboidAnnotation],
295+
"segmentation": List[SegmentationAnnotation],
296+
"category": List[CategoryAnnotation],
297+
}
298+
}]
299+
"""
300+
json_generator = paginate_generator(
301+
client=self._client,
302+
endpoint=f"slice/{self.id}/{model.id}/exportForTrainingPage",
303+
result_key=EXPORT_FOR_TRAINING_KEY,
304+
page_size=100000,
305+
)
306+
for data in json_generator:
307+
for ip in convert_export_payload([data], has_predictions=True):
308+
yield ip
309+
275310
def export_scale_task_info(self):
276311
"""Fetches info for all linked Scale tasks of items/scenes in the slice.
277312

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ exclude = '''
2121

2222
[tool.poetry]
2323
name = "scale-nucleus"
24-
version = "0.14.11"
24+
version = "0.14.12"
2525
description = "The official Python client library for Nucleus, the Data Platform for AI"
2626
license = "MIT"
2727
authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]

tests/test_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ def sort_labelmap(segmentation_annotation):
533533
assert row[ITEM_KEY] == ds_items[0]
534534
assert row[ANNOTATIONS_KEY][BOX_TYPE][0] == box_annotation
535535
assert sort_labelmap(
536-
row[ANNOTATIONS_KEY][SEGMENTATION_TYPE][0]
536+
row[ANNOTATIONS_KEY][SEGMENTATION_TYPE]
537537
) == sort_labelmap(clear_fields(segmentation_annotation))
538538
assert row[ANNOTATIONS_KEY][POLYGON_TYPE][0] == polygon_annotation
539539
assert row[ANNOTATIONS_KEY][CATEGORY_TYPE][0] == category_annotation

0 commit comments

Comments
 (0)