
Commit 5fffaef

Diego Ardila committed

First try :)
1 parent b664115 commit 5fffaef

File tree

2 files changed: +96 -3 lines


nucleus/dataset.py

Lines changed: 17 additions & 2 deletions
@@ -8,14 +8,15 @@
     serialize_and_write_to_presigned_url,
 )
 
-from .annotation import Annotation
+from .annotation import Annotation, check_all_annotation_paths_remote
 from .constants import (
     DATASET_ITEM_IDS_KEY,
     DATASET_LENGTH_KEY,
     DATASET_MODEL_RUNS_KEY,
     DATASET_NAME_KEY,
     DATASET_SLICES_KEY,
     DEFAULT_ANNOTATION_UPDATE_MODE,
+    JOB_ID_KEY,
     NAME_KEY,
     REFERENCE_IDS_KEY,
     REQUEST_ID_KEY,
@@ -143,7 +144,8 @@ def annotate(
         annotations: List[Annotation],
         update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
         batch_size: int = 5000,
-    ) -> dict:
+        asynchronous: bool = False,
+    ) -> Union[dict[str, Any], AsyncJob]:
         """
         Uploads ground truth annotations for a given dataset.
         :param annotations: ground truth annotations for a given dataset to upload
@@ -156,6 +158,19 @@ def annotate(
             "ignored_items": int,
         }
         """
+        if asynchronous:
+            check_all_annotation_paths_remote(annotations)
+
+            request_id = serialize_and_write_to_presigned_url(
+                annotations, self.id, self._client
+            )
+            response = self._client.make_request(
+                payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
+                route=f"dataset/{self.id}/annotate?async=1",
+            )
+
+            return AsyncJob(response[JOB_ID_KEY], self._client)
+
         return self._client.annotate_dataset(
             self.id, annotations, update=update, batch_size=batch_size
         )
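
This commit gives Dataset.annotate an asynchronous path: with asynchronous=True, every annotation must reference a remote asset (enforced by check_all_annotation_paths_remote), the serialized payload goes to a presigned URL, and the dataset/{id}/annotate?async=1 route returns a job id that gets wrapped in an AsyncJob handle. A minimal caller-side sketch of both modes, for illustration only; the API key, dataset id, and annotation values are hypothetical placeholders.

from nucleus import NucleusClient
from nucleus.annotation import BoxAnnotation

client = NucleusClient("YOUR_API_KEY")         # hypothetical key
dataset = client.get_dataset("ds_sample_123")  # hypothetical dataset id

annotations = [
    BoxAnnotation(
        label="car", x=0, y=0, width=10, height=10, reference_id="img_1"
    )
]

# Synchronous path (unchanged): blocks until the upload finishes and
# returns a response dict.
result = dataset.annotate(annotations=annotations, update=False)

# Asynchronous path (new in this commit): returns immediately with an
# AsyncJob handle; annotations pointing at local files are rejected
# up front by check_all_annotation_paths_remote.
job = dataset.annotate(annotations=annotations, asynchronous=True)
job.sleep_until_complete()  # poll until the server marks the job done
print(job.status())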

tests/test_dataset.py

Lines changed: 79 additions & 1 deletion
@@ -1,8 +1,16 @@
-from nucleus.job import JobError
+from nucleus.annotation import (
+    BoxAnnotation,
+    PolygonAnnotation,
+    SegmentationAnnotation,
+)
+from nucleus.job import AsyncJob, JobError
 import pytest
 import os
 
 from .helpers import (
+    TEST_BOX_ANNOTATIONS,
+    TEST_POLYGON_ANNOTATIONS,
+    TEST_SEGMENTATION_ANNOTATIONS,
     TEST_SLICE_NAME,
     TEST_DATASET_NAME,
     TEST_IMG_URLS,
@@ -238,3 +246,73 @@ def test_dataset_export_autotag_scores(CLIENT):
     for column in ["dataset_item_ids", "ref_ids", "scores"]:
         assert column in scores
         assert len(scores[column]) > 0
+
+
+def test_annotate_async(dataset: Dataset):
+    semseg = SegmentationAnnotation.from_json(TEST_SEGMENTATION_ANNOTATIONS[0])
+    polygon = PolygonAnnotation(**TEST_POLYGON_ANNOTATIONS[0])
+    bbox = BoxAnnotation(**TEST_BOX_ANNOTATIONS[0])
+    bbox.reference_id = "fake_garbage"
+
+    job: AsyncJob = dataset.annotate(
+        annotations=[semseg, polygon, bbox],
+        asynchronous=True,
+    )
+    job.sleep_until_complete()
+
+    assert job.status() == {
+        "job_id": job.id,
+        "status": "Completed",
+        "message": {
+            "annotation_upload": {
+                "epoch": 1,
+                "total": 2,
+                "errored": 0,
+                "ignored": 0,
+                "datasetId": dataset.id,
+                "processed": 2,
+            },
+            "segmentation_upload": {
+                "errors": [],
+                "ignored": 0,
+                "n_errors": 0,
+                "processed": 1,
+            },
+        },
+    }
+
+
+def test_annotate_async_with_error(dataset: Dataset):
+    semseg = SegmentationAnnotation.from_json(TEST_SEGMENTATION_ANNOTATIONS[0])
+    polygon = PolygonAnnotation(**TEST_POLYGON_ANNOTATIONS[0])
+    bbox = BoxAnnotation(**TEST_BOX_ANNOTATIONS[0])
+    bbox.reference_id = "fake_garbage"
+
+    job: AsyncJob = dataset.annotate(
+        annotations=[semseg, polygon, bbox],
+        asynchronous=True,
+    )
+    job.sleep_until_complete()
+
+    assert job.status() == {
+        "job_id": job.id,
+        "status": "Completed",
+        "message": {
+            "annotation_upload": {
+                "epoch": 1,
+                "total": 2,
+                "errored": 0,
+                "ignored": 0,
+                "datasetId": dataset.id,
+                "processed": 1,
+            },
+            "segmentation_upload": {
+                "errors": [],
+                "ignored": 0,
+                "n_errors": 0,
+                "processed": 1,
+            },
+        },
+    }
+
+    assert "Item with id fake_garbage doesn" in str(job.errors())
