Commit 05173ab

add check for duplicate ref_id, ann_id (#296)
1 parent 523ca50 commit 05173ab

9 files changed: +84 -22 lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -5,6 +5,12 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.10.8](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.10.8) - 2022-05-10
+
+### Fixed
+- Add checks for duplicate (`reference_id`, `annotation_id`) when uploading Annotations or Predictions
+
+
 ## [0.10.7](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.10.7) - 2022-05-09
 
 ### Fixed
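
For context on the changelog entry, a minimal sketch of the new behavior from the caller's side; the BoxAnnotation fields and values below are illustrative, not taken from this commit:

# Illustrative sketch of the 0.10.8 behavior; field values are made up.
from nucleus import BoxAnnotation

ann = BoxAnnotation(
    label="car",
    x=0,
    y=0,
    width=10,
    height=10,
    reference_id="image_1",
    annotation_id="ann_1",
)

# Two annotations sharing the same (reference_id, annotation_id) pair in one
# upload now raise DuplicateIDError instead of being accepted silently:
# dataset.annotate(annotations=[ann, ann])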

nucleus/annotation_uploader.py

Lines changed: 22 additions & 0 deletions
@@ -1,4 +1,5 @@
 import json
+from collections import Counter
 from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence
 
 from nucleus.annotation import Annotation, SegmentationAnnotation
@@ -8,6 +9,7 @@
     make_many_form_data_requests_concurrently,
 )
 from nucleus.constants import MASK_TYPE, SERIALIZED_REQUEST_KEY
+from nucleus.errors import DuplicateIDError
 from nucleus.payload_constructor import (
     construct_annotation_payload,
     construct_segmentation_payload,
@@ -208,6 +210,26 @@ def fn():
 
         return fn
 
+    @staticmethod
+    def check_for_duplicate_ids(annotations: Iterable[Annotation]):
+        """Do not allow annotations to have the same (annotation_id, reference_id) tuple"""
+
+        # Some annotations, like CategoryAnnotation, have no annotation_id attribute, so duplicates are allowed for them.
+        tuple_ids = [
+            (ann.reference_id, ann.annotation_id)  # type: ignore
+            for ann in annotations
+            if hasattr(ann, "annotation_id")
+        ]
+        tuple_count = Counter(tuple_ids)
+        duplicates = {key for key, value in tuple_count.items() if value > 1}
+        if len(duplicates) > 0:
+            raise DuplicateIDError(
+                f"Duplicate annotations with the same (reference_id, annotation_id) properties found.\n"
+                f"Duplicates: {duplicates}\n"
+                f"To fix this, avoid duplicate annotations, or specify a different annotation_id attribute "
+                f"for the failing items."
+            )
+
 
 class PredictionUploader(AnnotationUploader):
     def __init__(
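
To see the Counter logic above in isolation, here is a self-contained sketch runnable outside the client; FakeAnnotation is a hypothetical stand-in for the client's annotation classes:

# Self-contained sketch of the duplicate check above; FakeAnnotation is a
# hypothetical stand-in for the client's Annotation classes.
from collections import Counter
from typing import Iterable, NamedTuple, Set, Tuple


class FakeAnnotation(NamedTuple):
    reference_id: str
    annotation_id: str


def find_duplicates(annotations: Iterable[FakeAnnotation]) -> Set[Tuple[str, str]]:
    # Count each (reference_id, annotation_id) pair and keep those seen more than once.
    counts = Counter(
        (ann.reference_id, ann.annotation_id)
        for ann in annotations
        if hasattr(ann, "annotation_id")
    )
    return {key for key, value in counts.items() if value > 1}


anns = [
    FakeAnnotation("img_1", "a"),
    FakeAnnotation("img_1", "a"),
    FakeAnnotation("img_2", "a"),
]
print(find_duplicates(anns))  # {('img_1', 'a')}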

nucleus/dataset.py

Lines changed: 21 additions & 16 deletions
@@ -389,6 +389,9 @@ def annotate(
 
             Otherwise, returns an :class:`AsyncJob` object.
         """
+        uploader = AnnotationUploader(dataset_id=self.id, client=self._client)
+        uploader.check_for_duplicate_ids(annotations)
+
         if asynchronous:
             check_all_mask_paths_remote(annotations)
             request_id = serialize_and_write_to_presigned_url(
@@ -399,7 +402,7 @@
                 route=f"dataset/{self.id}/annotate?async=1",
             )
             return AsyncJob.from_json(response, self._client)
-        uploader = AnnotationUploader(dataset_id=self.id, client=self._client)
+
         return uploader.upload(
             annotations=annotations,
             update=update,
@@ -1405,6 +1408,14 @@ def upload_predictions(
                 "predictions_ignored": int,
             }
         """
+        uploader = PredictionUploader(
+            model_run_id=None,
+            dataset_id=self.id,
+            model_id=model.id,
+            client=self._client,
+        )
+        uploader.check_for_duplicate_ids(predictions)
+
         if asynchronous:
             check_all_mask_paths_remote(predictions)
 
@@ -1416,21 +1427,15 @@
                 route=f"dataset/{self.id}/model/{model.id}/uploadPredictions?async=1",
             )
             return AsyncJob.from_json(response, self._client)
-        else:
-            uploader = PredictionUploader(
-                model_run_id=None,
-                dataset_id=self.id,
-                model_id=model.id,
-                client=self._client,
-            )
-            return uploader.upload(
-                annotations=predictions,
-                batch_size=batch_size,
-                update=update,
-                remote_files_per_upload_request=remote_files_per_upload_request,
-                local_files_per_upload_request=local_files_per_upload_request,
-                local_file_upload_concurrency=local_file_upload_concurrency,
-            )
+
+        return uploader.upload(
+            annotations=predictions,
+            batch_size=batch_size,
+            update=update,
+            remote_files_per_upload_request=remote_files_per_upload_request,
+            local_files_per_upload_request=local_files_per_upload_request,
+            local_file_upload_concurrency=local_file_upload_concurrency,
+        )
 
     def predictions_iloc(self, model, index):
         """Fetches all predictions of a dataset item by its absolute index.

nucleus/errors.py

Lines changed: 6 additions & 0 deletions
@@ -72,3 +72,9 @@ def __init__(
     ):
         self.message = message
         super().__init__(self.message)
+
+
+class DuplicateIDError(Exception):
+    def __init__(self, message):
+        self.message = message
+        super().__init__(self.message)
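
DuplicateIDError subclasses Exception and keeps the formatted message on .message, so callers can catch it and surface the offending pairs. A small usage sketch, assuming scale-nucleus 0.10.8 or later is installed:

from nucleus.errors import DuplicateIDError

# The error stores the formatted message on .message as well as in the
# Exception args, so either can be shown to the user.
try:
    raise DuplicateIDError("Duplicates: {('image_1', 'ann_1')}")
except DuplicateIDError as err:
    print(err.message)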

nucleus/model_run.py

Lines changed: 6 additions & 3 deletions
@@ -154,6 +154,11 @@ def predict(
                 "predictions_ignored": int,
             }
         """
+        uploader = PredictionUploader(
+            model_run_id=self.model_run_id, client=self._client
+        )
+        uploader.check_for_duplicate_ids(annotations)
+
         if asynchronous:
             check_all_mask_paths_remote(annotations)
 
@@ -165,9 +170,7 @@
                 route=f"modelRun/{self.model_run_id}/predict?async=1",
             )
             return AsyncJob.from_json(response, self._client)
-        uploader = PredictionUploader(
-            model_run_id=self.model_run_id, client=self._client
-        )
+
         return uploader.upload(
             annotations=annotations,
             update=update,

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ exclude = '''
 
 [tool.poetry]
 name = "scale-nucleus"
-version = "0.10.7"
+version = "0.10.8"
 description = "The official Python client library for Nucleus, the Data Platform for AI"
 license = "MIT"
 authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]

tests/test_annotation.py

Lines changed: 8 additions & 0 deletions
@@ -15,6 +15,7 @@
     SegmentationAnnotation,
 )
 from nucleus.constants import ERROR_PAYLOAD
+from nucleus.errors import DuplicateIDError
 from nucleus.job import AsyncJob, JobError
 
 from .helpers import (
@@ -813,3 +814,10 @@ def test_box_gt_upload_embedding_async(CLIENT, dataset):
     status = job.status()
     assert status["job_id"] == job.job_id
     assert status["status"] == "Running"
+
+
+def test_annotation_duplicate_ids_fail(dataset):
+    box_ann = BoxAnnotation(**TEST_BOX_ANNOTATIONS[0])
+    annotations = [box_ann, box_ann]
+    with pytest.raises(DuplicateIDError):
+        dataset.annotate(annotations=annotations)

tests/test_dataset.py

Lines changed: 2 additions & 1 deletion
@@ -314,7 +314,8 @@ def test_dataset_append_async_with_local_path(dataset: Dataset):
     dataset.append(ds_items, asynchronous=True)
 
 
-@pytest.mark.integration
+# TODO(Jean): Fix and remove skip, this is a flaky test
+@pytest.mark.skip(reason="Flaky test")
 def test_dataset_append_async_with_1_bad_url(dataset: Dataset):
     ds_items = make_dataset_items()
     ds_items[0].image_location = "https://looks.ok.but.is.not.accessible"

tests/test_prediction.py

Lines changed: 12 additions & 1 deletion
@@ -1,4 +1,3 @@
-import os
 import time
 
 import pytest
@@ -16,6 +15,7 @@
     SegmentationPrediction,
 )
 from nucleus.constants import ERROR_PAYLOAD
+from nucleus.errors import DuplicateIDError
 from nucleus.job import AsyncJob, JobError
 
 from .helpers import (
@@ -724,3 +724,14 @@ def test_box_pred_upload_embedding_async(CLIENT, model_run):
     status = job.status()
     assert status["job_id"] == job.job_id
     assert status["status"] == "Running"
+
+
+def test_prediction_duplicate_ids_fail(dataset, model, model_run):
+    box_pred = BoxPrediction(**TEST_BOX_PREDICTIONS_EMBEDDINGS[0])
+    predictions = [box_pred, box_pred]
+
+    with pytest.raises(DuplicateIDError):
+        dataset.upload_predictions(model, predictions=predictions)
+
+    with pytest.raises(DuplicateIDError):
+        model_run.predict(annotations=predictions)
