Skip to content

Commit 7d8037f

Browse files
author
Diego Ardila
committed
Tests working
1 parent 17fb35b commit 7d8037f

File tree

5 files changed

+63
-7
lines changed

5 files changed

+63
-7
lines changed

nucleus/dataset.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,24 @@ def items_and_annotations(
430430
)
431431
return convert_export_payload(api_payload[EXPORTED_ROWS])
432432

433+
def export_embeddings(
    self,
) -> List[Dict[str, Union[str, List[float]]]]:
    """Fetch the embeddings of this dataset in a pd.DataFrame-ready format.

    Issues a GET request against the ``dataset/<id>/embeddings`` route and
    hands the API payload back to the caller untouched.

    Returns:
        A list with one dict per row of the dataset. Every dict carries
        two keys:

        * one holds the reference id of the row
        * the other holds the list of embedding values for that row
    """
    embeddings_route = f"dataset/{self.id}/embeddings"
    return self._client.make_request(
        payload=None,
        route=embeddings_route,
        requests_command=requests.get,
    )
450+
433451
def delete_annotations(
434452
self, reference_ids: list = None, keep_history=False
435453
):

nucleus/slice.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,24 @@ def send_to_labeling(self, project_id: str):
126126
)
127127
return AsyncJob.from_json(response, self._client)
128128

129+
def export_embeddings(
    self,
) -> List[Dict[str, Union[str, List[float]]]]:
    """Fetch the embeddings of this slice in a pd.DataFrame-ready format.

    Issues a GET request against the ``slice/<slice_id>/embeddings`` route
    and hands the API payload back to the caller untouched.

    Returns:
        A list with one dict per row of the slice. Every dict carries two
        keys:

        * one holds the reference id of the row
        * the other holds the list of embedding values for that row
    """
    embeddings_route = f"slice/{self.slice_id}/embeddings"
    return self._client.make_request(
        payload=None,
        route=embeddings_route,
        requests_command=requests.get,
    )
146+
129147

130148
def check_annotations_are_in_slice(
131149
annotations: List[Annotation], slice_to_check: Slice

tests/helpers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
TEST_SLICE_NAME = "[PyTest] Test Slice"
1414
TEST_PROJECT_ID = "60b699d70f139e002dd31bfc"
1515

16+
DATASET_WITH_AUTOTAG = "ds_c4dgj702e2vjft7m9xa0"
17+
NUCLEUS_PYTEST_USER_ID = "60ad648c85db770026e9bf77"
18+
1619

1720
TEST_IMG_URLS = [
1821
"https://github.com/scaleapi/nucleus-python-client/raw/master/tests/testdata/airplane.jpeg",

tests/test_dataset.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@
3737
TEST_IMG_URLS,
3838
TEST_POLYGON_ANNOTATIONS,
3939
TEST_SEGMENTATION_ANNOTATIONS,
40+
DATASET_WITH_AUTOTAG,
41+
NUCLEUS_PYTEST_USER_ID,
4042
reference_id_from_url,
4143
)
4244

43-
TEST_AUTOTAG_DATASET = "ds_bz43jm2jwm70060b3890"
44-
4545

4646
def test_reprs():
4747
# Have to define here in order to have access to all relevant objects
@@ -325,17 +325,17 @@ def test_raises_error_for_duplicate():
325325
def test_dataset_export_autotag_scores(CLIENT):
326326
# This test can only run for the test user who has an indexed dataset.
327327
# TODO: if/when we can create autotags via api, create one instead.
328-
if os.environ.get("HAS_ACCESS_TO_TEST_DATA", False):
329-
dataset = CLIENT.get_dataset(TEST_AUTOTAG_DATASET)
328+
if NUCLEUS_PYTEST_USER_ID in CLIENT.api_key:
329+
dataset = CLIENT.get_dataset(DATASET_WITH_AUTOTAG)
330330

331331
with pytest.raises(NucleusAPIError) as api_error:
332332
dataset.autotag_scores(autotag_name="NONSENSE_GARBAGE")
333333
assert (
334-
f"The autotag NONSENSE_GARBAGE was not found in dataset {TEST_AUTOTAG_DATASET}"
334+
f"The autotag NONSENSE_GARBAGE was not found in dataset {DATASET_WITH_AUTOTAG}"
335335
in str(api_error.value)
336336
)
337337

338-
scores = dataset.autotag_scores(autotag_name="TestTag")
338+
scores = dataset.autotag_scores(autotag_name="PytestTestTag")
339339

340340
for column in ["dataset_item_ids", "ref_ids", "scores"]:
341341
assert column in scores
@@ -484,3 +484,10 @@ def sort_labelmap(segmentation_annotation):
484484
assert exported[0][ANNOTATIONS_KEY][POLYGON_TYPE][0] == clear_fields(
485485
polygon_annotation
486486
)
487+
488+
489+
def test_export_embeddings(CLIENT):
    """Check that a dataset embedding export exposes the expected keys.

    The indexed dataset is only visible to the dedicated pytest user, so
    the test is reported as skipped (instead of silently passing without
    asserting anything) when run with any other API key.
    """
    if NUCLEUS_PYTEST_USER_ID not in CLIENT.api_key:
        pytest.skip("Requires the Nucleus pytest user's indexed dataset")

    embeddings = Dataset(DATASET_WITH_AUTOTAG, CLIENT).export_embeddings()
    first_row = embeddings[0]
    assert "embedding_vector" in first_row
    assert "reference_id" in first_row

tests/test_slice.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import copy
22
import pytest
33
import uuid
4-
from nucleus import Slice, NucleusClient, DatasetItem, BoxAnnotation
4+
from nucleus import Slice, NucleusClient, DatasetItem, BoxAnnotation, Dataset
55
from nucleus.constants import (
66
ANNOTATIONS_KEY,
77
BOX_TYPE,
@@ -14,6 +14,8 @@
1414
TEST_SLICE_NAME,
1515
TEST_BOX_ANNOTATIONS,
1616
TEST_PROJECT_ID,
17+
DATASET_WITH_AUTOTAG,
18+
NUCLEUS_PYTEST_USER_ID,
1719
reference_id_from_url,
1820
)
1921
from nucleus.job import AsyncJob
@@ -181,3 +183,11 @@ def test_slice_send_to_labeling(dataset):
181183

182184
response = slc.send_to_labeling(TEST_PROJECT_ID)
183185
assert isinstance(response, AsyncJob)
186+
187+
188+
def test_export_slice_embeddings(CLIENT):
    """Check that a slice embedding export exposes the expected keys.

    The slice referenced here belongs to the dedicated pytest user, so the
    test is reported as skipped (instead of silently passing without
    asserting anything) when run with any other API key.
    """
    if NUCLEUS_PYTEST_USER_ID not in CLIENT.api_key:
        pytest.skip("Requires the Nucleus pytest user's indexed slice")

    # Resolve the slice only after the access check so that nothing touches
    # the pytest user's data on behalf of a key that cannot see it.
    test_slice = CLIENT.get_slice("slc_c4s4ts3v7bw00b1hkj0g")
    embeddings = test_slice.export_embeddings()
    first_row = embeddings[0]
    assert "embedding_vector" in first_row
    assert "reference_id" in first_row

0 commit comments

Comments
 (0)