Skip to content

Commit d2f3bd9

Browse files
committed
Passes dataset test
1 parent 6386774 commit d2f3bd9

File tree

4 files changed

+90
-66
lines changed

4 files changed

+90
-66
lines changed

nucleus/dataset.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from typing import List, Dict, Any, Optional
2+
3+
from nucleus.utils import format_dataset_item_response
24
from .dataset_item import DatasetItem
35
from .annotation import (
46
Annotation,
@@ -11,10 +13,7 @@
1113
DATASET_ITEM_IDS_KEY,
1214
REFERENCE_IDS_KEY,
1315
NAME_KEY,
14-
ITEM_KEY,
1516
DEFAULT_ANNOTATION_UPDATE_MODE,
16-
ANNOTATIONS_KEY,
17-
ANNOTATION_TYPES,
1817
)
1918
from .payload_constructor import construct_model_run_creation_payload
2019

@@ -177,7 +176,7 @@ def iloc(self, i: int) -> dict:
177176
}
178177
"""
179178
response = self._client.dataitem_iloc(self.id, i)
180-
return self._format_dataset_item_response(response)
179+
return format_dataset_item_response(response)
181180

182181
def refloc(self, reference_id: str) -> dict:
183182
"""
@@ -190,7 +189,7 @@ def refloc(self, reference_id: str) -> dict:
190189
}
191190
"""
192191
response = self._client.dataitem_ref_id(self.id, reference_id)
193-
return self._format_dataset_item_response(response)
192+
return format_dataset_item_response(response)
194193

195194
def loc(self, dataset_item_id: str) -> dict:
196195
"""
@@ -203,7 +202,7 @@ def loc(self, dataset_item_id: str) -> dict:
203202
}
204203
"""
205204
response = self._client.dataitem_loc(self.id, dataset_item_id)
206-
return self._format_dataset_item_response(response)
205+
return format_dataset_item_response(response)
207206

208207
def create_slice(
209208
self,
@@ -245,25 +244,6 @@ def delete_item(self, item_id: str = None, reference_id: str = None):
245244
def list_autotags(self):
246245
return self._client.list_autotags(self.id)
247246

248-
def _format_dataset_item_response(self, response: dict) -> dict:
249-
item = response.get(ITEM_KEY, None)
250-
annotation_payload = response.get(ANNOTATIONS_KEY, {})
251-
if not item or not annotation_payload:
252-
# An error occurred
253-
return response
254-
255-
annotation_response = {}
256-
for annotation_type in ANNOTATION_TYPES:
257-
if annotation_type in annotation_payload:
258-
annotation_response[annotation_type] = [
259-
Annotation.from_json(ann)
260-
for ann in annotation_payload[annotation_type]
261-
]
262-
return {
263-
ITEM_KEY: DatasetItem.from_json(item),
264-
ANNOTATIONS_KEY: annotation_response,
265-
}
266-
267247
def create_custom_index(self, embeddings_url: str):
268248
return self._client.create_custom_index(self.id, embeddings_url)
269249

nucleus/slice.py

Lines changed: 40 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,9 @@
1-
from __future__ import annotations
2-
31
from typing import List, Iterable, Set, Tuple, Optional
42
from nucleus.dataset_item import DatasetItem
53
from nucleus.annotation import Annotation
4+
from nucleus.utils import format_dataset_item_response
65

7-
from .constants import DEFAULT_ANNOTATION_UPDATE_MODE
8-
9-
10-
def check_annotations_are_in_slice(
11-
annotations: List[Annotation], slice_to_check: Slice
12-
) -> Tuple[bool, Set[str], Set[str]]:
13-
"""Check membership of the annotation targets within this slice.
14-
15-
annotations: Annotations with ids referring to targets.
16-
slice: The slice to check against.
17-
"""
18-
info = slice_to_check.info()
19-
item_ids_not_found_in_slice = {
20-
annotation.item_id
21-
for annotation in annotations
22-
if annotation.item_id is not None
23-
}.difference({item_metadata["id"] for item_metadata in info})
24-
reference_ids_not_found_in_slice = {
25-
annotation.reference_id
26-
for annotation in annotations
27-
if annotation.reference_id is not None
28-
}.difference({item_metadata["reference_id"] for item_metadata in info})
29-
if item_ids_not_found_in_slice or reference_ids_not_found_in_slice:
30-
annotations_are_in_slice = False
31-
else:
32-
annotations_are_in_slice = True
33-
34-
return (
35-
annotations_are_in_slice,
36-
item_ids_not_found_in_slice,
37-
reference_ids_not_found_in_slice,
38-
)
6+
from .constants import DEFAULT_ANNOTATION_UPDATE_MODE, ITEM_KEY
397

408

419
class Slice:
@@ -106,13 +74,15 @@ def append(
10674
return response
10775

10876
def items_generator(self) -> Iterable[DatasetItem]:
109-
"""Returns an iterable of DatasetItems in this slice."""
77+
"""Returns an iterable of the DatasetItems in this slice."""
11078
info = self.info()
11179
for item_metadata in info["dataset_items"]:
112-
yield self._client.dataitem_loc(
113-
dataset_id=info["dataset_id"],
114-
dataset_item_id=item_metadata["id"],
115-
)
80+
yield format_dataset_item_response(
81+
self._client.dataitem_loc(
82+
dataset_id=info["dataset_id"],
83+
dataset_item_id=item_metadata["id"],
84+
)
85+
)[ITEM_KEY]
11686

11787
def items(self) -> List[DatasetItem]:
11888
"""Returns a list of all DatasetItems in this slice."""
@@ -152,3 +122,34 @@ def annotate(
152122
update=update,
153123
batch_size=batch_size,
154124
)
125+
126+
127+
def check_annotations_are_in_slice(
128+
annotations: List[Annotation], slice_to_check: Slice
129+
) -> Tuple[bool, Set[str], Set[str]]:
130+
"""Check membership of the annotation targets within this slice.
131+
132+
annotations: Annnotations with ids referring to targets.
133+
slice: The slice to check against.
134+
"""
135+
info = slice_to_check.info()
136+
item_ids_not_found_in_slice = {
137+
annotation.item_id
138+
for annotation in annotations
139+
if annotation.item_id is not None
140+
}.difference({item_metadata["id"] for item_metadata in info})
141+
reference_ids_not_found_in_slice = {
142+
annotation.reference_id
143+
for annotation in annotations
144+
if annotation.reference_id is not None
145+
}.difference({item_metadata["reference_id"] for item_metadata in info})
146+
if item_ids_not_found_in_slice or reference_ids_not_found_in_slice:
147+
annotations_are_in_slice = False
148+
else:
149+
annotations_are_in_slice = True
150+
151+
return (
152+
annotations_are_in_slice,
153+
item_ids_not_found_in_slice,
154+
reference_ids_not_found_in_slice,
155+
)

nucleus/utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
1+
"""Shared stateless utility function library"""
2+
3+
14
from typing import List, Union, Dict
25

6+
from nucleus.annotation import Annotation
37
from .dataset_item import DatasetItem
48
from .prediction import BoxPrediction, PolygonPrediction
59

10+
from .constants import (
11+
ITEM_KEY,
12+
ANNOTATIONS_KEY,
13+
ANNOTATION_TYPES,
14+
)
15+
616

717
def _get_all_field_values(metadata_list: List[dict], key: str):
    """Return the set of distinct values stored under ``key``.

    Metadata dicts that do not contain ``key`` are skipped.
    """
    values = set()
    for entry in metadata_list:
        if key in entry:
            values.add(entry[key])
    return values
@@ -34,3 +44,29 @@ def suggest_metadata_schema(
3444
entry["type"] = "text"
3545
schema[key] = entry
3646
return schema
47+
48+
49+
def format_dataset_item_response(response: dict) -> dict:
    """Convert a raw dataset-item response into API objects.

    Args:
        response: Raw client response; must contain both ``ANNOTATIONS_KEY``
            and ``ITEM_KEY``.

    Returns:
        A dict with ``ITEM_KEY`` mapped to a ``DatasetItem`` built from the
        item payload, and ``ANNOTATIONS_KEY`` mapped to a dict of
        ``Annotation`` lists keyed by each annotation type present in the
        payload.

    Raises:
        ValueError: If the annotation key or the item key is missing.
    """
    # The annotation key is validated first, then the item key.
    if ANNOTATIONS_KEY not in response:
        raise ValueError(
            f"Server response was missing the annotation key: {response}"
        )
    if ITEM_KEY not in response:
        raise ValueError(
            f"Server response was missing the item key: {response}"
        )

    raw_item = response[ITEM_KEY]
    raw_annotations = response[ANNOTATIONS_KEY]

    # Only annotation types actually present in the payload are kept.
    parsed_annotations = {
        annotation_type: [
            Annotation.from_json(raw_annotation)
            for raw_annotation in raw_annotations[annotation_type]
        ]
        for annotation_type in ANNOTATION_TYPES
        if annotation_type in raw_annotations
    }

    formatted = {}
    formatted[ITEM_KEY] = DatasetItem.from_json(raw_item)
    formatted[ANNOTATIONS_KEY] = parsed_annotations
    return formatted

tests/test_dataset.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,15 @@ def test_slice_append(dataset):
197197
all_stored_items = slc.items()
198198

199199
def sort_by_reference_id(items):
200+
# Remove the generated item_ids and standardize
201+
# empty metadata so we can do an equality check.
202+
for item in items:
203+
item.item_id = None
204+
if item.metadata == {}:
205+
item.metadata = None
200206
return sorted(items, key=lambda x: x.reference_id)
201207

202208
breakpoint()
203-
204-
assert tuple(sort_by_reference_id(all_stored_items)) == ds_items
209+
assert sort_by_reference_id(all_stored_items) == sort_by_reference_id(
210+
ds_items[:3]
211+
)

0 commit comments

Comments
 (0)