Skip to content

Commit 2e515a4

Browse files
authored
Merge pull request #57 from scaleapi/da/slice-support
Da/slice support
2 parents 4d09675 + 342d12f commit 2e515a4

File tree

5 files changed

+308
-98
lines changed

5 files changed

+308
-98
lines changed

nucleus/dataset.py

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from typing import List, Dict, Any, Optional, Union
1+
from typing import List, Dict, Any, Optional
2+
3+
from nucleus.utils import format_dataset_item_response
24
from .dataset_item import DatasetItem
35
from .annotation import (
46
Annotation,
5-
BoxAnnotation,
6-
PolygonAnnotation,
77
)
88
from .constants import (
99
DATASET_NAME_KEY,
@@ -13,10 +13,7 @@
1313
DATASET_ITEM_IDS_KEY,
1414
REFERENCE_IDS_KEY,
1515
NAME_KEY,
16-
ITEM_KEY,
1716
DEFAULT_ANNOTATION_UPDATE_MODE,
18-
ANNOTATIONS_KEY,
19-
ANNOTATION_TYPES,
2017
)
2118
from .payload_constructor import construct_model_run_creation_payload
2219

@@ -109,7 +106,7 @@ def create_model_run(
109106

110107
def annotate(
111108
self,
112-
annotations: List[Union[BoxAnnotation, PolygonAnnotation]],
109+
annotations: List[Annotation],
113110
update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
114111
batch_size: int = 5000,
115112
) -> dict:
@@ -179,7 +176,7 @@ def iloc(self, i: int) -> dict:
179176
}
180177
"""
181178
response = self._client.dataitem_iloc(self.id, i)
182-
return self._format_dataset_item_response(response)
179+
return format_dataset_item_response(response)
183180

184181
def refloc(self, reference_id: str) -> dict:
185182
"""
@@ -192,7 +189,7 @@ def refloc(self, reference_id: str) -> dict:
192189
}
193190
"""
194191
response = self._client.dataitem_ref_id(self.id, reference_id)
195-
return self._format_dataset_item_response(response)
192+
return format_dataset_item_response(response)
196193

197194
def loc(self, dataset_item_id: str) -> dict:
198195
"""
@@ -205,7 +202,7 @@ def loc(self, dataset_item_id: str) -> dict:
205202
}
206203
"""
207204
response = self._client.dataitem_loc(self.id, dataset_item_id)
208-
return self._format_dataset_item_response(response)
205+
return format_dataset_item_response(response)
209206

210207
def create_slice(
211208
self,
@@ -247,25 +244,6 @@ def delete_item(self, item_id: str = None, reference_id: str = None):
247244
def list_autotags(self):
248245
return self._client.list_autotags(self.id)
249246

250-
def _format_dataset_item_response(self, response: dict) -> dict:
251-
item = response.get(ITEM_KEY, None)
252-
annotation_payload = response.get(ANNOTATIONS_KEY, {})
253-
if not item or not annotation_payload:
254-
# An error occured
255-
return response
256-
257-
annotation_response = {}
258-
for annotation_type in ANNOTATION_TYPES:
259-
if annotation_type in annotation_payload:
260-
annotation_response[annotation_type] = [
261-
Annotation.from_json(ann)
262-
for ann in annotation_payload[annotation_type]
263-
]
264-
return {
265-
ITEM_KEY: DatasetItem.from_json(item),
266-
ANNOTATIONS_KEY: annotation_response,
267-
}
268-
269247
def create_custom_index(self, embeddings_url: str):
270248
return self._client.create_custom_index(self.id, embeddings_url)
271249

nucleus/slice.py

Lines changed: 132 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
from typing import List
1+
from typing import Dict, List, Iterable, Set, Tuple, Optional, Union
2+
from nucleus.dataset_item import DatasetItem
3+
from nucleus.annotation import Annotation
4+
from nucleus.utils import format_dataset_item_response
5+
6+
from .constants import DEFAULT_ANNOTATION_UPDATE_MODE
27

38

49
class Slice:
@@ -9,6 +14,7 @@ class Slice:
914
def __init__(self, slice_id: str, client):
1015
self.slice_id = slice_id
1116
self._client = client
17+
self._dataset_id = None
1218

1319
def __repr__(self):
1420
return f"Slice(slice_id='{self.slice_id}', client={self._client})"
@@ -19,6 +25,13 @@ def __eq__(self, other):
1925
return True
2026
return False
2127

28+
@property
def dataset_id(self):
    """Id of the dataset this slice belongs to.

    The id is not supplied at construction time. When no value has been
    recorded on the instance yet, ``info()`` is called first and the
    recorded value is returned afterwards.
    """
    if self._dataset_id is not None:
        return self._dataset_id
    # Nothing recorded yet: ask for the slice info, then re-read the field.
    self.info()
    return self._dataset_id
34+
2235
def info(self) -> dict:
2336
"""
2437
This endpoint provides information about specified slice.
@@ -30,7 +43,9 @@ def info(self) -> dict:
3043
"dataset_items",
3144
}
3245
"""
33-
return self._client.slice_info(self.slice_id)
46+
info = self._client.slice_info(self.slice_id)
47+
self._dataset_id = info["dataset_id"]
48+
return info
3449

3550
def append(
3651
self,
@@ -57,3 +72,118 @@ def append(
5772
reference_ids=reference_ids,
5873
)
5974
return response
75+
76+
def items_and_annotation_generator(
    self,
) -> Iterable[Dict[str, Union[DatasetItem, Dict[str, List[Annotation]]]]]:
    """Lazily yield every dataset item of this slice with its annotations.

    The slice info is fetched once when iteration starts. Then, for each
    entry listed under ``"dataset_items"``, the client's ``dataitem_loc``
    is called and its response is passed through
    ``format_dataset_item_response`` before being yielded.

    Returns:
        An iterable of dicts, one per dataset item in the slice, as
        produced by ``format_dataset_item_response``.
    """
    slice_info = self.info()
    for entry in slice_info["dataset_items"]:
        raw_response = self._client.dataitem_loc(
            dataset_id=slice_info["dataset_id"],
            dataset_item_id=entry["id"],
        )
        yield format_dataset_item_response(raw_response)
97+
98+
def items_and_annotations(
    self,
) -> List[Dict[str, Union[DatasetItem, Dict[str, List[Annotation]]]]]:
    """Return all dataset items of this slice, with annotations, as a list.

    This is the eager counterpart of ``items_and_annotation_generator``:
    the generator is exhausted and its rows are collected in order.

    Returns:
        A list with one dict per dataset item in the slice, exactly as
        yielded by ``items_and_annotation_generator``.
    """
    return [*self.items_and_annotation_generator()]
112+
113+
def annotate(
    self,
    annotations: List[Annotation],
    update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
    batch_size: int = 5000,
    strict: bool = True,
):
    """Update annotations within this slice.

    Args:
        annotations: List of annotations to upload.
        update: Forwarded unchanged to the client's ``annotate_dataset``.
        batch_size: How many annotations to send per request.
        strict: Whether to first check that the annotations belong to this
            slice. Set to false to avoid this check and speed up upload.

    Returns:
        Whatever the client's ``annotate_dataset`` call returns, so callers
        can inspect the outcome of the upload.

    Raises:
        ValueError: If ``strict`` is true and at least one annotation
            refers to an item id or reference id that is not in this slice.
    """
    if strict:
        (
            annotations_are_in_slice,
            item_ids_not_found_in_slice,
            reference_ids_not_found_in_slice,
        ) = check_annotations_are_in_slice(annotations, self)
        if not annotations_are_in_slice:
            message = "Not all annotations are in this slice.\n"
            if item_ids_not_found_in_slice:
                message += f"Item ids not found in slice: {item_ids_not_found_in_slice} \n"
            if reference_ids_not_found_in_slice:
                message += f"Reference ids not found in slice: {reference_ids_not_found_in_slice}"
            raise ValueError(message)
    # Hand the client's response back instead of silently dropping it.
    return self._client.annotate_dataset(
        dataset_id=self.dataset_id,
        annotations=annotations,
        update=update,
        batch_size=batch_size,
    )
147+
148+
149+
def check_annotations_are_in_slice(
    annotations: List[Annotation], slice_to_check: Slice
) -> Tuple[bool, Set[str], Set[str]]:
    """Check that every annotation targets an item contained in a slice.

    Args:
        annotations: Annotations whose ``item_id`` / ``reference_id``
            identify their target items. Ids that are ``None`` are ignored.
        slice_to_check: The slice to check against. Its ``info()`` is
            called once to obtain the slice's ``"dataset_items"``.

    Returns:
        A tuple ``(all_in_slice, missing_item_ids, missing_reference_ids)``:
        * ``all_in_slice`` is True only when both sets below are empty.
        * ``missing_item_ids`` is the set of annotation item ids that do
          not match the ``"id"`` of any item in the slice.
        * ``missing_reference_ids`` is the set of annotation reference ids
          that do not match the ``"ref_id"`` of any item in the slice.
    """
    info = slice_to_check.info()

    requested_item_ids = set()
    for annotation in annotations:
        if annotation.item_id is not None:
            requested_item_ids.add(annotation.item_id)
    slice_item_ids = {entry["id"] for entry in info["dataset_items"]}
    missing_item_ids = requested_item_ids - slice_item_ids

    requested_reference_ids = set()
    for annotation in annotations:
        if annotation.reference_id is not None:
            requested_reference_ids.add(annotation.reference_id)
    slice_reference_ids = {entry["ref_id"] for entry in info["dataset_items"]}
    missing_reference_ids = requested_reference_ids - slice_reference_ids

    all_in_slice = not (missing_item_ids or missing_reference_ids)
    return all_in_slice, missing_item_ids, missing_reference_ids

nucleus/utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
1+
"""Shared stateless utility function library"""
2+
3+
14
from typing import List, Union, Dict
25

6+
from nucleus.annotation import Annotation
37
from .dataset_item import DatasetItem
48
from .prediction import BoxPrediction, PolygonPrediction
59

10+
from .constants import (
11+
ITEM_KEY,
12+
ANNOTATIONS_KEY,
13+
ANNOTATION_TYPES,
14+
)
15+
616

717
def _get_all_field_values(metadata_list: List[dict], key: str):
    """Return the set of values stored under ``key`` in the given dicts.

    Dicts that do not contain ``key`` are skipped.
    """
    values = set()
    for metadata in metadata_list:
        if key in metadata:
            values.add(metadata[key])
    return values
@@ -34,3 +44,29 @@ def suggest_metadata_schema(
3444
entry["type"] = "text"
3545
schema[key] = entry
3646
return schema
47+
48+
49+
def format_dataset_item_response(response: dict) -> dict:
    """Convert a raw dataset-item response from the client into api objects.

    Args:
        response: Raw response dict; it must contain both ``ITEM_KEY`` and
            ``ANNOTATIONS_KEY``.

    Returns:
        A dict with two entries: ``ITEM_KEY`` mapped to the result of
        ``DatasetItem.from_json`` on the raw item, and ``ANNOTATIONS_KEY``
        mapped to a dict that, for each type in ``ANNOTATION_TYPES`` present
        in the payload, holds the list of ``Annotation.from_json`` results.

    Raises:
        ValueError: If ``ANNOTATIONS_KEY`` or ``ITEM_KEY`` is missing from
            ``response``. The annotation key is checked first.
    """
    if ANNOTATIONS_KEY not in response:
        raise ValueError(
            f"Server response was missing the annotation key: {response}"
        )
    if ITEM_KEY not in response:
        raise ValueError(
            f"Server response was missing the item key: {response}"
        )
    raw_item = response[ITEM_KEY]
    raw_annotations = response[ANNOTATIONS_KEY]

    # Only annotation types actually present in the payload get an entry.
    parsed_annotations = {
        annotation_type: [
            Annotation.from_json(raw) for raw in raw_annotations[annotation_type]
        ]
        for annotation_type in ANNOTATION_TYPES
        if annotation_type in raw_annotations
    }
    return {
        ITEM_KEY: DatasetItem.from_json(raw_item),
        ANNOTATIONS_KEY: parsed_annotations,
    }

tests/test_dataset.py

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -127,69 +127,3 @@ def test_dataset_list_autotags(CLIENT, dataset):
127127
# List of Autotags should be empty
128128
autotag_response = CLIENT.list_autotags(dataset.id)
129129
assert autotag_response == []
130-
131-
132-
def test_slice_create_and_delete_and_list(dataset):
133-
# Dataset upload
134-
ds_items = []
135-
for url in TEST_IMG_URLS:
136-
ds_items.append(
137-
DatasetItem(
138-
image_location=url,
139-
reference_id=reference_id_from_url(url),
140-
)
141-
)
142-
response = dataset.append(ds_items)
143-
assert ERROR_PAYLOAD not in response.json()
144-
145-
# Slice creation
146-
slc = dataset.create_slice(
147-
name=TEST_SLICE_NAME,
148-
reference_ids=[item.reference_id for item in ds_items[:2]],
149-
)
150-
151-
dataset_slices = dataset.slices
152-
assert len(dataset_slices) == 1
153-
assert slc.slice_id == dataset_slices[0]
154-
155-
response = slc.info()
156-
assert response["name"] == TEST_SLICE_NAME
157-
assert response["dataset_id"] == dataset.id
158-
assert len(response["dataset_items"]) == 2
159-
for item in ds_items[:2]:
160-
assert (
161-
item.reference_id == response["dataset_items"][0]["ref_id"]
162-
or item.reference_id == response["dataset_items"][1]["ref_id"]
163-
)
164-
165-
166-
def test_slice_append(dataset):
167-
# Dataset upload
168-
ds_items = []
169-
for url in TEST_IMG_URLS:
170-
ds_items.append(
171-
DatasetItem(
172-
image_location=url,
173-
reference_id=reference_id_from_url(url),
174-
)
175-
)
176-
response = dataset.append(ds_items)
177-
assert ERROR_PAYLOAD not in response.json()
178-
179-
# Slice creation
180-
slc = dataset.create_slice(
181-
name=TEST_SLICE_NAME,
182-
reference_ids=[ds_items[0].reference_id],
183-
)
184-
185-
# Insert duplicate first item
186-
slc.append(reference_ids=[item.reference_id for item in ds_items[:3]])
187-
188-
response = slc.info()
189-
assert len(response["dataset_items"]) == 3
190-
for item in ds_items[:3]:
191-
assert (
192-
item.reference_id == response["dataset_items"][0]["ref_id"]
193-
or item.reference_id == response["dataset_items"][1]["ref_id"]
194-
or item.reference_id == response["dataset_items"][2]["ref_id"]
195-
)

0 commit comments

Comments
 (0)