Skip to content

Commit 879a2db

Browse files
authored
dataset.get_{image|object}_indexing_status (#230)
Adds methods that report the indexing status of a dataset's primary image or object index: the total number of items and the number of items already indexed.
1 parent b8e7139 commit 879a2db

File tree

5 files changed

+82
-8
lines changed

5 files changed

+82
-8
lines changed

nucleus/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,10 @@ def make_request(
841841
if payload is None:
842842
payload = {}
843843
if requests_command is requests.get:
844+
if payload:
845+
print(
846+
"Received defined payload with GET request! Will ignore payload"
847+
)
844848
payload = None
845849
return self._connection.make_request(payload, route, requests_command) # type: ignore
846850

nucleus/dataset.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,6 +1007,45 @@ def set_continuous_indexing(self, enable: bool = True):
10071007

10081008
return response
10091009

1010+
def get_image_indexing_status(self):
    """Gets the primary image index progress for the dataset.

    Returns:
        Response payload::

            {
                "embedding_count": int,
                "image_count": int,
                "percent_indexed": float,
                "additional_context": str,
            }
    """
    # The image and object variants hit the same ``indexingStatus`` route;
    # the ``"image"`` flag in the payload selects which index is reported.
    # NOTE(review): sent as a POST, presumably because the client's
    # make_request drops any payload on GET requests -- confirm.
    return self._client.make_request(
        {"image": True},
        f"dataset/{self.id}/indexingStatus",
        requests_command=requests.post,
    )
1028+
1029+
def get_object_indexing_status(self, model_run_id=None):
    """Gets the primary object index progress of the dataset.

    If model_run_id is not specified, this endpoint will retrieve the
    indexing progress of the ground truth objects.

    Args:
        model_run_id: Optional model run ID, forwarded unchanged in the
            request payload. Defaults to ``None``.

    Returns:
        Response payload::

            {
                "embedding_count": int,
                "object_count": int,
                "percent_indexed": float,
                "additional_context": str,
            }
    """
    # Same route as the image variant; ``"image": False`` selects the
    # object index. ``model_run_id`` is always included, even when None.
    return self._client.make_request(
        {"image": False, "model_run_id": model_run_id},
        f"dataset/{self.id}/indexingStatus",
        requests_command=requests.post,
    )
1048+
10101049
def create_image_index(self):
10111050
"""Creates or updates image index by generating embeddings for images that do not already have embeddings.
10121051

tests/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
TEST_SLICE_NAME = "[PyTest] Test Slice"
1717
TEST_PROJECT_ID = "60b699d70f139e002dd31bfc"
1818

19-
DATASET_WITH_AUTOTAG = "ds_c8jwdhy4y4f0078hzceg"
19+
DATASET_WITH_EMBEDDINGS = "ds_c8jwdhy4y4f0078hzceg"
2020
NUCLEUS_PYTEST_USER_ID = "60ad648c85db770026e9bf77"
2121

2222
EVAL_FUNCTION_THRESHOLD = 0.5

tests/test_autotag.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,18 @@
44

55
from nucleus.dataset import Dataset
66
from nucleus.errors import NucleusAPIError
7-
from tests.helpers import DATASET_WITH_AUTOTAG, running_as_nucleus_pytest_user
7+
from tests.helpers import (
8+
DATASET_WITH_EMBEDDINGS,
9+
running_as_nucleus_pytest_user,
10+
)
811

912
# TODO: Test delete_autotag once API support for autotag creation is added.
1013

1114

1215
@pytest.mark.integration
1316
def test_update_autotag(CLIENT):
1417
if running_as_nucleus_pytest_user(CLIENT):
15-
job = Dataset(DATASET_WITH_AUTOTAG, CLIENT).update_autotag(
18+
job = Dataset(DATASET_WITH_EMBEDDINGS, CLIENT).update_autotag(
1619
"tag_c8jwr0rpy1w00e134an0"
1720
)
1821
job.sleep_until_complete()
@@ -24,12 +27,12 @@ def test_dataset_export_autotag_training_items(CLIENT):
2427
# This test can only run for the test user who has an indexed dataset.
2528
# TODO: if/when we can create autotags via api, create one instead.
2629
if running_as_nucleus_pytest_user(CLIENT):
27-
dataset = CLIENT.get_dataset(DATASET_WITH_AUTOTAG)
30+
dataset = CLIENT.get_dataset(DATASET_WITH_EMBEDDINGS)
2831

2932
with pytest.raises(NucleusAPIError) as api_error:
3033
dataset.autotag_training_items(autotag_name="NONSENSE_GARBAGE")
3134
assert (
32-
f"The autotag NONSENSE_GARBAGE was not found in dataset {DATASET_WITH_AUTOTAG}"
35+
f"The autotag NONSENSE_GARBAGE was not found in dataset {DATASET_WITH_EMBEDDINGS}"
3336
in str(api_error.value)
3437
)
3538

@@ -52,7 +55,9 @@ def test_dataset_export_autotag_training_items(CLIENT):
5255

5356
def test_export_embeddings(CLIENT):
5457
if running_as_nucleus_pytest_user(CLIENT):
55-
embeddings = Dataset(DATASET_WITH_AUTOTAG, CLIENT).export_embeddings()
58+
embeddings = Dataset(
59+
DATASET_WITH_EMBEDDINGS, CLIENT
60+
).export_embeddings()
5661
assert "embedding_vector" in embeddings[0]
5762
assert "reference_id" in embeddings[0]
5863

@@ -61,12 +66,12 @@ def test_dataset_export_autotag_tagged_items(CLIENT):
6166
# This test can only run for the test user who has an indexed dataset.
6267
# TODO: if/when we can create autotags via api, create one instead.
6368
if running_as_nucleus_pytest_user(CLIENT):
64-
dataset = CLIENT.get_dataset(DATASET_WITH_AUTOTAG)
69+
dataset = CLIENT.get_dataset(DATASET_WITH_EMBEDDINGS)
6570

6671
with pytest.raises(NucleusAPIError) as api_error:
6772
dataset.autotag_items(autotag_name="NONSENSE_GARBAGE")
6873
assert (
69-
f"The autotag NONSENSE_GARBAGE was not found in dataset {DATASET_WITH_AUTOTAG}"
74+
f"The autotag NONSENSE_GARBAGE was not found in dataset {DATASET_WITH_EMBEDDINGS}"
7075
in str(api_error.value)
7176
)
7277

tests/test_dataset.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from nucleus.job import AsyncJob, JobError
3030

3131
from .helpers import (
32+
DATASET_WITH_EMBEDDINGS,
3233
LOCAL_FILENAME,
3334
TEST_BOX_ANNOTATIONS,
3435
TEST_CATEGORY_ANNOTATIONS,
@@ -556,3 +557,28 @@ def test_dataset_item_iterator(dataset):
556557
}
557558
for key in expected_items:
558559
assert actual_items[key] == expected_items[key]
560+
561+
562+
@pytest.mark.integration
def test_dataset_get_image_indexing_status(CLIENT):
    """Checks the image indexing status of the pre-indexed test dataset.

    The expected counts are pinned to the current contents of
    ``DATASET_WITH_EMBEDDINGS``.
    """
    dataset = Dataset(DATASET_WITH_EMBEDDINGS, CLIENT)
    resp = dataset.get_image_indexing_status()
    assert resp["embedding_count"] == 170
    assert resp["image_count"] == 170
    # The image variant must not leak the object-index field.
    assert "object_count" not in resp
    # Indexed fraction is indexed items over total items, so the embedding
    # count is the numerator (it also keeps the denominator non-zero here).
    assert round(resp["percent_indexed"], 2) == round(
        resp["embedding_count"] / resp["image_count"], 2
    )
573+
574+
575+
@pytest.mark.integration
def test_dataset_get_object_indexing_status(CLIENT):
    """Checks the object indexing status of the pre-indexed test dataset.

    No ``model_run_id`` is passed, so the default (``None``) is used. The
    expected counts are pinned to the current contents of
    ``DATASET_WITH_EMBEDDINGS``.
    """
    dataset = Dataset(DATASET_WITH_EMBEDDINGS, CLIENT)
    resp = dataset.get_object_indexing_status()
    assert resp["embedding_count"] == 422
    assert resp["object_count"] == 423
    # The object variant must not leak the image-index field.
    assert "image_count" not in resp
    # Indexed fraction is indexed objects over total objects, so the
    # embedding count is the numerator.
    assert round(resp["percent_indexed"], 2) == round(
        resp["embedding_count"] / resp["object_count"], 2
    )

0 commit comments

Comments
 (0)