Skip to content

Commit 947be07

Browse files
authored
Merge pull request #76 from scaleapi/da/export
Batch export functionality for slice and dataset + point fix
2 parents 58cdcde + e2303c7 commit 947be07

13 files changed

+247
-97
lines changed

nucleus/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
PolygonAnnotation,
6868
Segment,
6969
SegmentationAnnotation,
70+
Point,
7071
)
7172
from .constants import (
7273
ANNOTATION_METADATA_SCHEMA_KEY,

nucleus/annotation.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
from dataclasses import dataclass
33
from enum import Enum
4-
from typing import Any, Dict, List, Optional, Sequence, Union
4+
from typing import Dict, List, Optional, Sequence, Union
55
from nucleus.dataset_item import is_local_path
66

77
from .constants import (
@@ -174,11 +174,23 @@ def to_payload(self) -> dict:
174174
}
175175

176176

177-
# TODO: Add Generic type for 2D point
177+
@dataclass
class Point:
    """A 2D point used as a polygon vertex.

    Attributes:
        x: horizontal coordinate.
        y: vertical coordinate.
    """

    x: float
    y: float

    @classmethod
    def from_json(cls, payload: Dict[str, float]):
        """Construct a Point from a serialized ``{x, y}`` payload dict."""
        return cls(x=payload[X_KEY], y=payload[Y_KEY])

    def to_payload(self) -> dict:
        """Serialize this point back into the ``{x, y}`` wire format."""
        payload = {X_KEY: self.x}
        payload[Y_KEY] = self.y
        return payload
188+
189+
178190
@dataclass
179191
class PolygonAnnotation(Annotation):
180192
label: str
181-
vertices: List[Any]
193+
vertices: List[Point]
182194
reference_id: Optional[str] = None
183195
item_id: Optional[str] = None
184196
annotation_id: Optional[str] = None
@@ -187,28 +199,46 @@ class PolygonAnnotation(Annotation):
187199
def __post_init__(self):
    """Validate ids, default the metadata, and coerce raw dict vertices
    into Point objects for backward compatibility."""
    self._check_ids()
    self.metadata = self.metadata or {}
    if not self.vertices:
        return
    first = self.vertices[0]
    # Already point-like objects (duck-typed on the x attribute and
    # to_payload) — nothing to convert.
    if hasattr(first, X_KEY) and hasattr(first, "to_payload"):
        return
    try:
        # Legacy callers passed [{"x": ..., "y": ...}, ...]; upgrade them.
        self.vertices = [
            Point(x=raw[X_KEY], y=raw[Y_KEY]) for raw in self.vertices
        ]
    except KeyError as ke:
        raise ValueError(
            "Use a point object to pass in vertices. For example, vertices=[nucleus.Point(x=1, y=2)]"
        ) from ke
190215

191216
@classmethod
def from_json(cls, payload: dict):
    """Deserialize a polygon annotation payload into a PolygonAnnotation.

    Missing keys fall back to the same defaults the serializer emits.
    """
    geometry = payload.get(GEOMETRY_KEY, {})
    vertices = [
        Point.from_json(raw_point)
        for raw_point in geometry.get(VERTICES_KEY, [])
    ]
    return cls(
        label=payload.get(LABEL_KEY, 0),
        vertices=vertices,
        reference_id=payload.get(REFERENCE_ID_KEY, None),
        item_id=payload.get(DATASET_ITEM_ID_KEY, None),
        annotation_id=payload.get(ANNOTATION_ID_KEY, None),
        metadata=payload.get(METADATA_KEY, {}),
    )
202229

203230
def to_payload(self) -> dict:
    """Serialize this annotation into the polygon wire format,
    converting each Point vertex back to its dict form."""
    return {
        LABEL_KEY: self.label,
        TYPE_KEY: POLYGON_TYPE,
        GEOMETRY_KEY: {
            VERTICES_KEY: [vertex.to_payload() for vertex in self.vertices]
        },
        REFERENCE_ID_KEY: self.reference_id,
        ANNOTATION_ID_KEY: self.annotation_id,
        METADATA_KEY: self.metadata,
    }
212242

213243

214244
def check_all_annotation_paths_remote(

nucleus/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
ANNOTATION_TYPES = (BOX_TYPE, POLYGON_TYPE, SEGMENTATION_TYPE)
1111
ANNOTATION_UPDATE_KEY = "update"
1212
AUTOTAGS_KEY = "autotags"
13-
13+
EXPORTED_ROWS = "exportedRows"
1414
CLASS_PDF_KEY = "class_pdf"
1515
CONFIDENCE_KEY = "confidence"
1616
DATASET_ID_KEY = "dataset_id"

nucleus/dataset.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from nucleus.job import AsyncJob
66
from nucleus.utils import (
7+
convert_export_payload,
78
format_dataset_item_response,
89
serialize_and_write_to_presigned_url,
910
)
@@ -16,6 +17,7 @@
1617
DATASET_NAME_KEY,
1718
DATASET_SLICES_KEY,
1819
DEFAULT_ANNOTATION_UPDATE_MODE,
20+
EXPORTED_ROWS,
1921
JOB_ID_KEY,
2022
NAME_KEY,
2123
REFERENCE_IDS_KEY,
@@ -327,3 +329,23 @@ def delete_custom_index(self):
327329

328330
def check_index_status(self, job_id: str):
329331
return self._client.check_index_status(job_id)
332+
333+
def items_and_annotations(
    self,
) -> List[Dict[str, Union[DatasetItem, Dict[str, List[Annotation]]]]]:
    """Returns a list of all DatasetItems and Annotations in this dataset.

    Returns:
        A list, where each item is a dict with two keys representing a row
        in the dataset.
        * One value in the dict is the DatasetItem, containing a reference to the
          item that was annotated.
        * The other value is a dictionary containing all the annotations for this
          dataset item, sorted by annotation type.
    """
    # GET the full export; the server returns rows under EXPORTED_ROWS.
    api_payload = self._client.make_request(
        payload=None,
        route=f"dataset/{self.id}/exportForTraining",
        requests_command=requests.get,
    )
    return convert_export_payload(api_payload[EXPORTED_ROWS])

nucleus/prediction.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from typing import Dict, Optional, List, Any
1+
from typing import Dict, Optional, List
22
from .annotation import (
33
BoxAnnotation,
4+
Point,
45
PolygonAnnotation,
56
Segment,
67
SegmentationAnnotation,
@@ -102,7 +103,7 @@ class PolygonPrediction(PolygonAnnotation):
102103
def __init__(
103104
self,
104105
label: str,
105-
vertices: List[Any],
106+
vertices: List[Point],
106107
reference_id: Optional[str] = None,
107108
item_id: Optional[str] = None,
108109
confidence: Optional[float] = None,
@@ -135,7 +136,9 @@ def from_json(cls, payload: dict):
135136
geometry = payload.get(GEOMETRY_KEY, {})
136137
return cls(
137138
label=payload.get(LABEL_KEY, 0),
138-
vertices=geometry.get(VERTICES_KEY, []),
139+
vertices=[
140+
Point.from_json(_) for _ in geometry.get(VERTICES_KEY, [])
141+
],
139142
reference_id=payload.get(REFERENCE_ID_KEY, None),
140143
item_id=payload.get(DATASET_ITEM_ID_KEY, None),
141144
confidence=payload.get(CONFIDENCE_KEY, None),

nucleus/slice.py

Lines changed: 12 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
from typing import Dict, List, Iterable, Set, Tuple, Optional, Union
2-
from nucleus.dataset_item import DatasetItem
1+
from typing import Dict, Iterable, List, Set, Tuple, Union
2+
3+
import requests
4+
35
from nucleus.annotation import Annotation
4-
from nucleus.utils import format_dataset_item_response
6+
from nucleus.dataset_item import DatasetItem
57
from nucleus.job import AsyncJob
6-
7-
from .constants import DEFAULT_ANNOTATION_UPDATE_MODE
8+
from nucleus.utils import convert_export_payload, format_dataset_item_response
9+
from nucleus.constants import EXPORTED_ROWS
810

911

1012
class Slice:
@@ -109,42 +111,12 @@ def items_and_annotations(
109111
* The other value is a dictionary containing all the annotations for this
110112
dataset item, sorted by annotation type.
111113
"""
112-
return list(self.items_and_annotation_generator())
113-
114-
def annotate(
115-
self,
116-
annotations: List[Annotation],
117-
update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
118-
batch_size: int = 5000,
119-
strict=True,
120-
):
121-
"""Update annotations within this slice.
122-
123-
Args:
124-
annotations: List of annotations to upload
125-
batch_size: How many annotations to send per request.
126-
strict: Whether to first check that the annotations belong to this slice.
127-
Set to false to avoid this check and speed up upload.
128-
"""
129-
if strict:
130-
(
131-
annotations_are_in_slice,
132-
item_ids_not_found_in_slice,
133-
reference_ids_not_found_in_slice,
134-
) = check_annotations_are_in_slice(annotations, self)
135-
if not annotations_are_in_slice:
136-
message = "Not all annotations are in this slice.\n"
137-
if item_ids_not_found_in_slice:
138-
message += f"Item ids not found in slice: {item_ids_not_found_in_slice} \n"
139-
if reference_ids_not_found_in_slice:
140-
message += f"Reference ids not found in slice: {reference_ids_not_found_in_slice}"
141-
raise ValueError(message)
142-
self._client.annotate_dataset(
143-
dataset_id=self.dataset_id,
144-
annotations=annotations,
145-
update=update,
146-
batch_size=batch_size,
114+
api_payload = self._client.make_request(
115+
payload=None,
116+
route=f"slice/{self.slice_id}/exportForTraining",
117+
requests_command=requests.get,
147118
)
119+
return convert_export_payload(api_payload[EXPORTED_ROWS])
148120

149121
def send_to_labeling(self, project_id: str):
150122
response = self._client.make_request(

nucleus/utils.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,29 @@
11
"""Shared stateless utility function library"""
22

3-
3+
from collections import defaultdict
44
import io
55
import uuid
66
from typing import IO, Dict, List, Sequence, Union
77

88
import requests
99
from requests.models import HTTPError
1010

11-
from nucleus.annotation import Annotation
12-
13-
from .constants import ANNOTATION_TYPES, ANNOTATIONS_KEY, ITEM_KEY
11+
from nucleus.annotation import (
12+
Annotation,
13+
BoxAnnotation,
14+
PolygonAnnotation,
15+
SegmentationAnnotation,
16+
)
17+
18+
from .constants import (
19+
ANNOTATION_TYPES,
20+
ANNOTATIONS_KEY,
21+
BOX_TYPE,
22+
ITEM_KEY,
23+
POLYGON_TYPE,
24+
REFERENCE_ID_KEY,
25+
SEGMENTATION_TYPE,
26+
)
1427
from .dataset_item import DatasetItem
1528
from .prediction import BoxPrediction, PolygonPrediction
1629

@@ -73,6 +86,31 @@ def format_dataset_item_response(response: dict) -> dict:
7386
}
7487

7588

89+
def convert_export_payload(api_payload):
    """Convert an exportForTraining API response into client objects.

    Args:
        api_payload: list of row dicts, each containing an ITEM_KEY payload
            plus per-annotation-type entries keyed by annotation type.

    Returns:
        A list with one dict per dataset item: ITEM_KEY maps to the parsed
        DatasetItem, and ANNOTATIONS_KEY maps to a defaultdict(list) of
        annotations grouped by annotation type.
    """
    return_payload = []
    for row in api_payload:
        return_payload_row = {}
        return_payload_row[ITEM_KEY] = DatasetItem.from_json(row[ITEM_KEY])
        annotations = defaultdict(list)
        # The reference_id lives on the item payload, not on each annotation,
        # so propagate it before deserializing. Hoisted out of the loops.
        reference_id = row[ITEM_KEY][REFERENCE_ID_KEY]
        if row.get(SEGMENTATION_TYPE) is not None:
            segmentation = row[SEGMENTATION_TYPE]
            segmentation[REFERENCE_ID_KEY] = reference_id
            # NOTE: stored as a single annotation (not a list), unlike the
            # other types — preserved as the existing contract.
            annotations[SEGMENTATION_TYPE] = SegmentationAnnotation.from_json(
                segmentation
            )
        # .get(..., []) guards against rows that omit an annotation type,
        # which previously raised KeyError via direct indexing.
        for polygon in row.get(POLYGON_TYPE, []):
            polygon[REFERENCE_ID_KEY] = reference_id
            annotations[POLYGON_TYPE].append(
                PolygonAnnotation.from_json(polygon)
            )
        for box in row.get(BOX_TYPE, []):
            box[REFERENCE_ID_KEY] = reference_id
            annotations[BOX_TYPE].append(BoxAnnotation.from_json(box))
        return_payload_row[ANNOTATIONS_KEY] = annotations
        return_payload.append(return_payload_row)
    return return_payload
112+
113+
76114
def serialize_and_write(
77115
upload_units: Sequence[Union[DatasetItem, Annotation]], file_pointer
78116
):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ exclude = '''
2121

2222
[tool.poetry]
2323
name = "scale-nucleus"
24-
version = "0.1.9"
24+
version = "0.1.10"
2525
description = "The official Python client library for Nucleus, the Data Platform for AI"
2626
license = "MIT"
2727
authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]

tests/helpers.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,15 @@ def reference_id_from_url(url):
5757
TEST_POLYGON_ANNOTATIONS = [
5858
{
5959
"label": f"[Pytest] Polygon Annotation ${i}",
60-
"vertices": [
61-
{
62-
"x": 50 + i * 10 + j,
63-
"y": 60 + i * 10 + j,
64-
}
65-
for j in range(3)
66-
],
60+
"geometry": {
61+
"vertices": [
62+
{
63+
"x": 50 + i * 10 + j,
64+
"y": 60 + i * 10 + j,
65+
}
66+
for j in range(3)
67+
],
68+
},
6769
"reference_id": reference_id_from_url(TEST_IMG_URLS[i]),
6870
"annotation_id": f"[Pytest] Polygon Annotation Annotation Id{i}",
6971
}
@@ -149,10 +151,10 @@ def assert_polygon_annotation_matches_dict(
149151
annotation_instance.annotation_id == annotation_dict["annotation_id"]
150152
)
151153
for instance_pt, dict_pt in zip(
152-
annotation_instance.vertices, annotation_dict["vertices"]
154+
annotation_instance.vertices, annotation_dict["geometry"]["vertices"]
153155
):
154-
assert instance_pt["x"] == dict_pt["x"]
155-
assert instance_pt["y"] == dict_pt["y"]
156+
assert instance_pt.x == dict_pt["x"]
157+
assert instance_pt.y == dict_pt["y"]
156158

157159

158160
def assert_segmentation_annotation_matches_dict(

0 commit comments

Comments
 (0)