Skip to content

Commit becb53b

Browse files
author
Diego Ardila
committed
slice export working
1 parent c89989f commit becb53b

File tree

3 files changed

+78
-60
lines changed

3 files changed

+78
-60
lines changed

nucleus/slice.py

Lines changed: 11 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
from typing import Dict, List, Iterable, Set, Tuple, Optional, Union
2-
from nucleus.dataset_item import DatasetItem
1+
from typing import Dict, Iterable, List, Set, Tuple, Union
2+
3+
import requests
4+
35
from nucleus.annotation import Annotation
4-
from nucleus.utils import format_dataset_item_response
6+
from nucleus.dataset_item import DatasetItem
57
from nucleus.job import AsyncJob
6-
7-
from .constants import DEFAULT_ANNOTATION_UPDATE_MODE
8+
from nucleus.utils import convert_export_payload, format_dataset_item_response
89

910

1011
class Slice:
@@ -109,42 +110,12 @@ def items_and_annotations(
109110
* The other value is a dictionary containing all the annotations for this
110111
dataset item, sorted by annotation type.
111112
"""
112-
return list(self.items_and_annotation_generator())
113-
114-
def annotate(
115-
self,
116-
annotations: List[Annotation],
117-
update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
118-
batch_size: int = 5000,
119-
strict=True,
120-
):
121-
"""Update annotations within this slice.
122-
123-
Args:
124-
annotations: List of annotations to upload
125-
batch_size: How many annotations to send per request.
126-
strict: Whether to first check that the annotations belong to this slice.
127-
Set to false to avoid this check and speed up upload.
128-
"""
129-
if strict:
130-
(
131-
annotations_are_in_slice,
132-
item_ids_not_found_in_slice,
133-
reference_ids_not_found_in_slice,
134-
) = check_annotations_are_in_slice(annotations, self)
135-
if not annotations_are_in_slice:
136-
message = "Not all annotations are in this slice.\n"
137-
if item_ids_not_found_in_slice:
138-
message += f"Item ids not found in slice: {item_ids_not_found_in_slice} \n"
139-
if reference_ids_not_found_in_slice:
140-
message += f"Reference ids not found in slice: {reference_ids_not_found_in_slice}"
141-
raise ValueError(message)
142-
self._client.annotate_dataset(
143-
dataset_id=self.dataset_id,
144-
annotations=annotations,
145-
update=update,
146-
batch_size=batch_size,
113+
api_payload = self._client.make_request(
114+
payload=None,
115+
route=f"slice/{self.slice_id}/exportForTraining",
116+
requests_command=requests.get,
147117
)
118+
return convert_export_payload(api_payload["exportedRows"])
148119

149120
def send_to_labeling(self, project_id: str):
150121
response = self._client.make_request(

nucleus/utils.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,29 @@
11
"""Shared stateless utility function library"""
22

3-
3+
from collections import defaultdict
44
import io
55
import uuid
66
from typing import IO, Dict, List, Sequence, Union
77

88
import requests
99
from requests.models import HTTPError
1010

11-
from nucleus.annotation import Annotation
12-
13-
from .constants import ANNOTATION_TYPES, ANNOTATIONS_KEY, ITEM_KEY
11+
from nucleus.annotation import (
12+
Annotation,
13+
BoxAnnotation,
14+
PolygonAnnotation,
15+
SegmentationAnnotation,
16+
)
17+
18+
from .constants import (
19+
ANNOTATION_TYPES,
20+
ANNOTATIONS_KEY,
21+
BOX_TYPE,
22+
ITEM_KEY,
23+
POLYGON_TYPE,
24+
REFERENCE_ID_KEY,
25+
SEGMENTATION_TYPE,
26+
)
1427
from .dataset_item import DatasetItem
1528
from .prediction import BoxPrediction, PolygonPrediction
1629

@@ -73,6 +86,31 @@ def format_dataset_item_response(response: dict) -> dict:
7386
}
7487

7588

89+
def convert_export_payload(api_payload):
90+
return_payload = []
91+
for row in api_payload:
92+
return_payload_row = {}
93+
return_payload_row[ITEM_KEY] = DatasetItem.from_json(row[ITEM_KEY])
94+
annotations = defaultdict(list)
95+
if row[SEGMENTATION_TYPE] is not None:
96+
segmentation = row[SEGMENTATION_TYPE]
97+
segmentation[REFERENCE_ID_KEY] = row[ITEM_KEY][REFERENCE_ID_KEY]
98+
annotations[SEGMENTATION_TYPE] = SegmentationAnnotation.from_json(
99+
segmentation
100+
)
101+
for polygon in row[POLYGON_TYPE]:
102+
polygon[REFERENCE_ID_KEY] = row[ITEM_KEY][REFERENCE_ID_KEY]
103+
annotations[POLYGON_TYPE].append(
104+
PolygonAnnotation.from_json(polygon)
105+
)
106+
for box in row[BOX_TYPE]:
107+
box[REFERENCE_ID_KEY] = row[ITEM_KEY][REFERENCE_ID_KEY]
108+
annotations[BOX_TYPE].append(BoxAnnotation.from_json(box))
109+
return_payload_row[ANNOTATIONS_KEY] = annotations
110+
return_payload.append(return_payload_row)
111+
return return_payload
112+
113+
76114
def serialize_and_write(
77115
upload_units: Sequence[Union[DatasetItem, Annotation]], file_pointer
78116
):

tests/test_slice.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
1+
import copy
12
import pytest
23
from nucleus import Slice, NucleusClient, DatasetItem, BoxAnnotation
3-
from nucleus.constants import ERROR_PAYLOAD, ITEM_KEY
4+
from nucleus.constants import (
5+
ANNOTATIONS_KEY,
6+
BOX_TYPE,
7+
ERROR_PAYLOAD,
8+
ITEM_KEY,
9+
)
410
from .helpers import (
511
TEST_DATASET_NAME,
612
TEST_IMG_URLS,
@@ -64,38 +70,41 @@ def test_slice_create_and_delete_and_list(dataset):
6470
)
6571

6672

67-
def test_slice_create_and_annotate(dataset):
73+
def test_slice_create_and_export(dataset):
6874
# Dataset upload
6975
url = TEST_IMG_URLS[0]
7076
annotation_in_slice = BoxAnnotation(**TEST_BOX_ANNOTATIONS[0])
71-
annotation_not_in_slice = BoxAnnotation(**TEST_BOX_ANNOTATIONS[1])
7277

73-
ds_items = []
74-
ds_items.append(
78+
ds_items = [
7579
DatasetItem(
7680
image_location=url,
7781
reference_id=reference_id_from_url(url),
78-
)
79-
)
82+
metadata={"test": "metadata"},
83+
),
84+
DatasetItem(
85+
image_location=url,
86+
reference_id="different_item",
87+
metadata={"test": "metadata"},
88+
),
89+
]
8090
response = dataset.append(ds_items)
8191
assert ERROR_PAYLOAD not in response.json()
8292

8393
# Slice creation
8494
slc = dataset.create_slice(
8595
name=TEST_SLICE_NAME,
86-
reference_ids=[item.reference_id for item in ds_items[:2]],
96+
reference_ids=[item.reference_id for item in ds_items[:1]],
8797
)
8898

89-
slc.annotate(annotations=[annotation_in_slice])
90-
with pytest.raises(ValueError) as not_in_slice_error:
91-
slc.annotate(annotations=[annotation_not_in_slice])
99+
dataset.annotate(annotations=[annotation_in_slice])
92100

93-
assert (
94-
annotation_not_in_slice.reference_id
95-
in not_in_slice_error.value.args[0]
96-
)
101+
expected_box_annotation = copy.deepcopy(annotation_in_slice)
102+
expected_box_annotation.annotation_id = None
103+
expected_box_annotation.metadata = {}
97104

98-
slc.annotate(annotations=[annotation_not_in_slice], strict=False)
105+
exported = slc.items_and_annotations()
106+
assert exported[0][ITEM_KEY] == ds_items[0]
107+
assert exported[0][ANNOTATIONS_KEY][BOX_TYPE][0] == expected_box_annotation
99108

100109

101110
def test_slice_append(dataset):

0 commit comments

Comments
 (0)