Skip to content

Commit ab868f0

Browse files
ardila and Ubuntu authored
API for updating autotag (#161)
* tests cleaned up and still pass * more docs * fix tests not running in circleci * fix broken test by creating slice Co-authored-by: Ubuntu <diego.ardila@scale.com>
1 parent b9d87c3 commit ab868f0

File tree

5 files changed

+125
-79
lines changed

5 files changed

+125
-79
lines changed

nucleus/dataset.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,28 @@ def list_autotags(self):
700700
"""
701701
return self._client.list_autotags(self.id)
702702

703+
def update_autotag(self, autotag_id: str):
    """Rerun autotag inference on all dataset items in the dataset.

    For now this endpoint does not try to skip items that have already been
    inferenced, but this improvement is planned for the future. This means
    that for now, you can only have one job running at a time, so please
    await the result using ``job.sleep_until_complete()`` before launching
    another job.

    Parameters:
        autotag_id: ID of the autotag to re-run inference for. You can find
            the ID you want by using ``dataset.list_autotags``, or by looking
            at the URL in the manage autotag page.

    Returns:
        :class:`AsyncJob`: Asynchronous job object to track processing status.
    """
    # POST autotag/<id> starts the server-side inference job; the response
    # payload describes the async job to poll.
    return AsyncJob.from_json(
        payload=self._client.make_request(
            {}, f"autotag/{autotag_id}", requests.post
        ),
        client=self._client,
    )
724+
703725
def create_custom_index(
704726
self, embeddings_urls: List[str], embedding_dim: int
705727
):

tests/helpers.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
import time
1+
import os
22
from pathlib import Path
3-
from urllib.parse import urlparse
43

54
from nucleus import BoxPrediction, DatasetItem
65

@@ -402,3 +401,11 @@ def assert_category_prediction_matches_dict(
402401
prediction_instance, prediction_dict
403402
)
404403
assert prediction_instance.confidence == prediction_dict["confidence"]
404+
405+
406+
def running_as_nucleus_pytest_user(client):
    """Return True when the test session runs as the Nucleus pytest user.

    Detection is via either the client's API key containing the pytest user
    id, or the NUCLEUS_PYTEST_USER_ID environment variable matching it.
    """
    return (
        NUCLEUS_PYTEST_USER_ID in client.api_key
        or os.environ.get("NUCLEUS_PYTEST_USER_ID") == NUCLEUS_PYTEST_USER_ID
    )

tests/test_autotag.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,95 @@
1+
import os
2+
3+
import pytest
4+
5+
from nucleus.dataset import Dataset
6+
from nucleus.errors import NucleusAPIError
7+
from tests.helpers import DATASET_WITH_AUTOTAG, running_as_nucleus_pytest_user
8+
19
# TODO: Test delete_autotag once API support for autotag creation is added.
10+
11+
12+
@pytest.mark.integration
def test_update_autotag(CLIENT):
    """Re-run autotag inference end-to-end and wait for job completion.

    Only runs for the Nucleus pytest user, who owns the indexed dataset.
    """
    if not running_as_nucleus_pytest_user(CLIENT):
        return
    dataset = Dataset(DATASET_WITH_AUTOTAG, CLIENT)
    job = dataset.update_autotag("tag_c5jwvzzde8c00604mkx0")
    job.sleep_until_complete()
    assert job.status()["status"] == "Completed"
21+
22+
23+
def test_dataset_export_autotag_training_items(CLIENT):
    """Export positive training items for an existing autotag and check shape."""
    # This test can only run for the test user who has an indexed dataset.
    # TODO: if/when we can create autotags via api, create one instead.
    if not running_as_nucleus_pytest_user(CLIENT):
        return
    dataset = CLIENT.get_dataset(DATASET_WITH_AUTOTAG)

    # Unknown autotag names must surface a clear API error.
    with pytest.raises(NucleusAPIError) as api_error:
        dataset.autotag_training_items(autotag_name="NONSENSE_GARBAGE")
    assert (
        f"The autotag NONSENSE_GARBAGE was not found in dataset {DATASET_WITH_AUTOTAG}"
        in str(api_error.value)
    )

    response = dataset.autotag_training_items(autotag_name="PytestTestTag")

    assert "autotagPositiveTrainingItems" in response
    assert "autotag" in response

    training_items = response["autotagPositiveTrainingItems"]
    tag = response["autotag"]

    assert len(training_items) > 0
    for training_item in training_items:
        assert "ref_id" in training_item

    for field in ("id", "name", "status", "autotag_level"):
        assert field in tag
51+
52+
53+
def test_export_embeddings(CLIENT):
    """Dataset embedding export exposes a vector and a reference id per row."""
    if not running_as_nucleus_pytest_user(CLIENT):
        return
    rows = Dataset(DATASET_WITH_AUTOTAG, CLIENT).export_embeddings()
    first = rows[0]
    assert "embedding_vector" in first
    assert "reference_id" in first
58+
59+
60+
def test_dataset_export_autotag_tagged_items(CLIENT):
    """Export tagged items for an existing autotag and sanity-check the payload."""
    # This test can only run for the test user who has an indexed dataset.
    # TODO: if/when we can create autotags via api, create one instead.
    if not running_as_nucleus_pytest_user(CLIENT):
        return
    dataset = CLIENT.get_dataset(DATASET_WITH_AUTOTAG)

    # Unknown autotag names must surface a clear API error.
    with pytest.raises(NucleusAPIError) as api_error:
        dataset.autotag_items(autotag_name="NONSENSE_GARBAGE")
    assert (
        f"The autotag NONSENSE_GARBAGE was not found in dataset {DATASET_WITH_AUTOTAG}"
        in str(api_error.value)
    )

    response = dataset.autotag_items(autotag_name="PytestTestTag")

    assert "autotagItems" in response
    assert "autotag" in response

    tagged_items = response["autotagItems"]
    tag = response["autotag"]

    assert len(tagged_items) > 0
    for tagged_item in tagged_items:
        for field in ("ref_id", "score"):
            assert field in tagged_item

    for field in ("id", "name", "status", "autotag_level"):
        assert field in tag
88+
89+
90+
def test_export_slice_embeddings(CLIENT):
    """Slice-level embedding export has the same row shape as dataset export."""
    if not running_as_nucleus_pytest_user(CLIENT):
        return
    pytest_slice = CLIENT.get_slice("slc_c6kcx5mrzr7g0c9d8cng")
    rows = pytest_slice.export_embeddings()
    first = rows[0]
    assert "embedding_vector" in first
    assert "reference_id" in first

tests/test_dataset.py

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@
3838
from nucleus.prediction import BoxPrediction
3939

4040
from .helpers import (
41-
DATASET_WITH_AUTOTAG,
4241
LOCAL_FILENAME,
43-
NUCLEUS_PYTEST_USER_ID,
4442
TEST_BOX_ANNOTATIONS,
4543
TEST_CATEGORY_ANNOTATIONS,
4644
TEST_DATASET_NAME,
@@ -354,66 +352,6 @@ def test_raises_error_for_duplicate():
354352
)
355353

356354

357-
def test_dataset_export_autotag_tagged_items(CLIENT):
358-
# This test can only run for the test user who has an indexed dataset.
359-
# TODO: if/when we can create autotags via api, create one instead.
360-
if os.environ.get("NUCLEUS_PYTEST_USER_ID") == NUCLEUS_PYTEST_USER_ID:
361-
dataset = CLIENT.get_dataset(DATASET_WITH_AUTOTAG)
362-
363-
with pytest.raises(NucleusAPIError) as api_error:
364-
dataset.autotag_items(autotag_name="NONSENSE_GARBAGE")
365-
assert (
366-
f"The autotag NONSENSE_GARBAGE was not found in dataset {DATASET_WITH_AUTOTAG}"
367-
in str(api_error.value)
368-
)
369-
370-
items = dataset.autotag_items(autotag_name="PytestTestTag")
371-
372-
assert "autotagItems" in items
373-
assert "autotag" in items
374-
375-
autotagItems = items["autotagItems"]
376-
autotag = items["autotag"]
377-
378-
assert len(autotagItems) > 0
379-
for item in autotagItems:
380-
for column in ["ref_id", "score"]:
381-
assert column in item
382-
383-
for column in ["id", "name", "status", "autotag_level"]:
384-
assert column in autotag
385-
386-
387-
def test_dataset_export_autotag_training_items(CLIENT):
388-
# This test can only run for the test user who has an indexed dataset.
389-
# TODO: if/when we can create autotags via api, create one instead.
390-
if os.environ.get("NUCLEUS_PYTEST_USER_ID") == NUCLEUS_PYTEST_USER_ID:
391-
dataset = CLIENT.get_dataset(DATASET_WITH_AUTOTAG)
392-
393-
with pytest.raises(NucleusAPIError) as api_error:
394-
dataset.autotag_training_items(autotag_name="NONSENSE_GARBAGE")
395-
assert (
396-
f"The autotag NONSENSE_GARBAGE was not found in dataset {DATASET_WITH_AUTOTAG}"
397-
in str(api_error.value)
398-
)
399-
400-
items = dataset.autotag_training_items(autotag_name="PytestTestTag")
401-
402-
assert "autotagPositiveTrainingItems" in items
403-
assert "autotag" in items
404-
405-
autotagTrainingItems = items["autotagPositiveTrainingItems"]
406-
autotag = items["autotag"]
407-
408-
assert len(autotagTrainingItems) > 0
409-
for item in autotagTrainingItems:
410-
for column in ["ref_id"]:
411-
assert column in item
412-
413-
for column in ["id", "name", "status", "autotag_level"]:
414-
assert column in autotag
415-
416-
417355
@pytest.mark.integration
418356
def test_annotate_async(dataset: Dataset):
419357
dataset.append(make_dataset_items())
@@ -578,10 +516,3 @@ def sort_labelmap(segmentation_annotation):
578516
exported[0][ANNOTATIONS_KEY][MULTICATEGORY_TYPE][0]
579517
== multicategory_annotation
580518
)
581-
582-
583-
def test_export_embeddings(CLIENT):
584-
if NUCLEUS_PYTEST_USER_ID in CLIENT.api_key:
585-
embeddings = Dataset(DATASET_WITH_AUTOTAG, CLIENT).export_embeddings()
586-
assert "embedding_vector" in embeddings[0]
587-
assert "reference_id" in embeddings[0]

tests/test_slice.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -184,11 +184,3 @@ def test_slice_send_to_labeling(dataset):
184184

185185
response = slc.send_to_labeling(TEST_PROJECT_ID)
186186
assert isinstance(response, AsyncJob)
187-
188-
189-
def test_export_slice_embeddings(CLIENT):
190-
test_slice = CLIENT.get_slice("slc_c4s4ts3v7bw00b1hkj0g")
191-
if NUCLEUS_PYTEST_USER_ID in CLIENT.api_key:
192-
embeddings = test_slice.export_embeddings()
193-
assert "embedding_vector" in embeddings[0]
194-
assert "reference_id" in embeddings[0]

0 commit comments

Comments (0)