Skip to content

Commit e6e4d85

Browse files
author
Diego Ardila
committed
Merge master
2 parents 5d561b7 + 2e515a4 commit e6e4d85

File tree

4 files changed

+275
-98
lines changed

4 files changed

+275
-98
lines changed

nucleus/dataset.py

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from typing import List, Dict, Any, Optional, Union
1+
from typing import List, Dict, Any, Optional
2+
3+
from nucleus.utils import format_dataset_item_response
24
from .dataset_item import DatasetItem
35
from .annotation import (
46
Annotation,
5-
BoxAnnotation,
6-
PolygonAnnotation,
77
)
88
from .constants import (
99
DATASET_NAME_KEY,
@@ -13,10 +13,7 @@
1313
DATASET_ITEM_IDS_KEY,
1414
REFERENCE_IDS_KEY,
1515
NAME_KEY,
16-
ITEM_KEY,
1716
DEFAULT_ANNOTATION_UPDATE_MODE,
18-
ANNOTATIONS_KEY,
19-
ANNOTATION_TYPES,
2017
)
2118
from .payload_constructor import construct_model_run_creation_payload
2219

@@ -130,7 +127,7 @@ def create_model_run(
130127

131128
def annotate(
132129
self,
133-
annotations: List[Union[BoxAnnotation, PolygonAnnotation]],
130+
annotations: List[Annotation],
134131
update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
135132
batch_size: int = 5000,
136133
) -> dict:
@@ -200,7 +197,7 @@ def iloc(self, i: int) -> dict:
200197
}
201198
"""
202199
response = self._client.dataitem_iloc(self.id, i)
203-
return self._format_dataset_item_response(response)
200+
return format_dataset_item_response(response)
204201

205202
def refloc(self, reference_id: str) -> dict:
206203
"""
@@ -213,7 +210,7 @@ def refloc(self, reference_id: str) -> dict:
213210
}
214211
"""
215212
response = self._client.dataitem_ref_id(self.id, reference_id)
216-
return self._format_dataset_item_response(response)
213+
return format_dataset_item_response(response)
217214

218215
def loc(self, dataset_item_id: str) -> dict:
219216
"""
@@ -226,7 +223,7 @@ def loc(self, dataset_item_id: str) -> dict:
226223
}
227224
"""
228225
response = self._client.dataitem_loc(self.id, dataset_item_id)
229-
return self._format_dataset_item_response(response)
226+
return format_dataset_item_response(response)
230227

231228
def create_slice(
232229
self,
@@ -268,25 +265,6 @@ def delete_item(self, item_id: str = None, reference_id: str = None):
268265
def list_autotags(self):
269266
return self._client.list_autotags(self.id)
270267

271-
def _format_dataset_item_response(self, response: dict) -> dict:
272-
item = response.get(ITEM_KEY, None)
273-
annotation_payload = response.get(ANNOTATIONS_KEY, {})
274-
if not item or not annotation_payload:
275-
# An error occured
276-
return response
277-
278-
annotation_response = {}
279-
for annotation_type in ANNOTATION_TYPES:
280-
if annotation_type in annotation_payload:
281-
annotation_response[annotation_type] = [
282-
Annotation.from_json(ann)
283-
for ann in annotation_payload[annotation_type]
284-
]
285-
return {
286-
ITEM_KEY: DatasetItem.from_json(item),
287-
ANNOTATIONS_KEY: annotation_response,
288-
}
289-
290268
def create_custom_index(self, embeddings_url: str):
291269
return self._client.create_custom_index(self.id, embeddings_url)
292270

nucleus/slice.py

Lines changed: 132 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
from typing import List
1+
from typing import Dict, List, Iterable, Set, Tuple, Optional, Union
2+
from nucleus.dataset_item import DatasetItem
3+
from nucleus.annotation import Annotation
4+
from nucleus.utils import format_dataset_item_response
5+
6+
from .constants import DEFAULT_ANNOTATION_UPDATE_MODE
27

38

49
class Slice:
@@ -9,6 +14,7 @@ class Slice:
914
def __init__(self, slice_id: str, client):
1015
self.slice_id = slice_id
1116
self._client = client
17+
self._dataset_id = None
1218

1319
def __repr__(self):
1420
return f"Slice(slice_id='{self.slice_id}', client={self._client})"
@@ -19,6 +25,13 @@ def __eq__(self, other):
1925
return True
2026
return False
2127

28+
@property
29+
def dataset_id(self):
30+
"""The id of the dataset this slice belongs to."""
31+
if self._dataset_id is None:
32+
self.info()
33+
return self._dataset_id
34+
2235
def info(self) -> dict:
2336
"""
2437
This endpoint provides information about specified slice.
@@ -30,7 +43,9 @@ def info(self) -> dict:
3043
"dataset_items",
3144
}
3245
"""
33-
return self._client.slice_info(self.slice_id)
46+
info = self._client.slice_info(self.slice_id)
47+
self._dataset_id = info["dataset_id"]
48+
return info
3449

3550
def append(
3651
self,
@@ -57,3 +72,118 @@ def append(
5772
reference_ids=reference_ids,
5873
)
5974
return response
75+
76+
def items_and_annotation_generator(
77+
self,
78+
) -> Iterable[Dict[str, Union[DatasetItem, Dict[str, List[Annotation]]]]]:
79+
"""Returns an iterable of all DatasetItems and Annotations in this slice.
80+
81+
Returns:
82+
An iterable, where each item is a dict with two keys representing a row
83+
in the dataset.
84+
* One value in the dict is the DatasetItem, containing a reference to the
85+
item that was annotated, for example an image_url.
86+
* The other value is a dictionary containing all the annotations for this
87+
dataset item, sorted by annotation type.
88+
"""
89+
info = self.info()
90+
for item_metadata in info["dataset_items"]:
91+
yield format_dataset_item_response(
92+
self._client.dataitem_loc(
93+
dataset_id=info["dataset_id"],
94+
dataset_item_id=item_metadata["id"],
95+
)
96+
)
97+
98+
def items_and_annotations(
99+
self,
100+
) -> List[Dict[str, Union[DatasetItem, Dict[str, List[Annotation]]]]]:
101+
"""Returns a list of all DatasetItems and Annotations in this slice.
102+
103+
Returns:
104+
A list, where each item is a dict with two keys representing a row
105+
in the dataset.
106+
* One value in the dict is the DatasetItem, containing a reference to the
107+
item that was annotated.
108+
* The other value is a dictionary containing all the annotations for this
109+
dataset item, sorted by annotation type.
110+
"""
111+
return list(self.items_and_annotation_generator())
112+
113+
def annotate(
114+
self,
115+
annotations: List[Annotation],
116+
update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
117+
batch_size: int = 5000,
118+
strict=True,
119+
):
120+
"""Update annotations within this slice.
121+
122+
Args:
123+
annotations: List of annotations to upload
124+
batch_size: How many annotations to send per request.
125+
strict: Whether to first check that the annotations belong to this slice.
126+
Set to false to avoid this check and speed up upload.
127+
"""
128+
if strict:
129+
(
130+
annotations_are_in_slice,
131+
item_ids_not_found_in_slice,
132+
reference_ids_not_found_in_slice,
133+
) = check_annotations_are_in_slice(annotations, self)
134+
if not annotations_are_in_slice:
135+
message = "Not all annotations are in this slice.\n"
136+
if item_ids_not_found_in_slice:
137+
message += f"Item ids not found in slice: {item_ids_not_found_in_slice} \n"
138+
if reference_ids_not_found_in_slice:
139+
message += f"Reference ids not found in slice: {reference_ids_not_found_in_slice}"
140+
raise ValueError(message)
141+
self._client.annotate_dataset(
142+
dataset_id=self.dataset_id,
143+
annotations=annotations,
144+
update=update,
145+
batch_size=batch_size,
146+
)
147+
148+
149+
def check_annotations_are_in_slice(
150+
annotations: List[Annotation], slice_to_check: Slice
151+
) -> Tuple[bool, Set[str], Set[str]]:
152+
"""Check membership of the annotation targets within this slice.
153+
154+
annotations: Annnotations with ids referring to targets.
155+
slice: The slice to check against.
156+
157+
158+
Returns:
159+
A tuple, where the first element is true/false whether the annotations are all
160+
in the slice.
161+
The second element is the list of item_ids not in the slice.
162+
The third element is the list of ref_ids not in the slice.
163+
"""
164+
info = slice_to_check.info()
165+
166+
item_ids_not_found_in_slice = {
167+
annotation.item_id
168+
for annotation in annotations
169+
if annotation.item_id is not None
170+
}.difference(
171+
{item_metadata["id"] for item_metadata in info["dataset_items"]}
172+
)
173+
reference_ids_not_found_in_slice = {
174+
annotation.reference_id
175+
for annotation in annotations
176+
if annotation.reference_id is not None
177+
}.difference(
178+
{item_metadata["ref_id"] for item_metadata in info["dataset_items"]}
179+
)
180+
if item_ids_not_found_in_slice or reference_ids_not_found_in_slice:
181+
annotations_are_in_slice = False
182+
else:
183+
annotations_are_in_slice = True
184+
185+
return (
186+
annotations_are_in_slice,
187+
item_ids_not_found_in_slice,
188+
reference_ids_not_found_in_slice,
189+
)

tests/test_dataset.py

Lines changed: 3 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -129,74 +129,11 @@ def test_dataset_list_autotags(CLIENT, dataset):
129129
assert autotag_response == []
130130

131131

132-
def test_dataset_export_autotag_scores_raises_not_found(CLIENT):
132+
def test_dataset_export_autotag_scores(CLIENT):
133+
# Pandoc dataset.
133134
client.get_dataset("ds_bwhjbyfb8mjj0ykagxf0")
134135

135136
# TODO: if/when we can create autotags via api, create one instead.
136137
dataset.autotag_scores(autotag_name="red_car_v2")
137138

138-
139-
def test_slice_create_and_delete_and_list(dataset):
140-
# Dataset upload
141-
ds_items = []
142-
for url in TEST_IMG_URLS:
143-
ds_items.append(
144-
DatasetItem(
145-
image_location=url,
146-
reference_id=reference_id_from_url(url),
147-
)
148-
)
149-
response = dataset.append(ds_items)
150-
assert ERROR_PAYLOAD not in response.json()
151-
152-
# Slice creation
153-
slc = dataset.create_slice(
154-
name=TEST_SLICE_NAME,
155-
reference_ids=[item.reference_id for item in ds_items[:2]],
156-
)
157-
158-
dataset_slices = dataset.slices
159-
assert len(dataset_slices) == 1
160-
assert slc.slice_id == dataset_slices[0]
161-
162-
response = slc.info()
163-
assert response["name"] == TEST_SLICE_NAME
164-
assert response["dataset_id"] == dataset.id
165-
assert len(response["dataset_items"]) == 2
166-
for item in ds_items[:2]:
167-
assert (
168-
item.reference_id == response["dataset_items"][0]["ref_id"]
169-
or item.reference_id == response["dataset_items"][1]["ref_id"]
170-
)
171-
172-
173-
def test_slice_append(dataset):
174-
# Dataset upload
175-
ds_items = []
176-
for url in TEST_IMG_URLS:
177-
ds_items.append(
178-
DatasetItem(
179-
image_location=url,
180-
reference_id=reference_id_from_url(url),
181-
)
182-
)
183-
response = dataset.append(ds_items)
184-
assert ERROR_PAYLOAD not in response.json()
185-
186-
# Slice creation
187-
slc = dataset.create_slice(
188-
name=TEST_SLICE_NAME,
189-
reference_ids=[ds_items[0].reference_id],
190-
)
191-
192-
# Insert duplicate first item
193-
slc.append(reference_ids=[item.reference_id for item in ds_items[:3]])
194-
195-
response = slc.info()
196-
assert len(response["dataset_items"]) == 3
197-
for item in ds_items[:3]:
198-
assert (
199-
item.reference_id == response["dataset_items"][0]["ref_id"]
200-
or item.reference_id == response["dataset_items"][1]["ref_id"]
201-
or item.reference_id == response["dataset_items"][2]["ref_id"]
202-
)
139+
# TODO: add some asserts?

0 commit comments

Comments
 (0)