Skip to content

Commit a1e0b69

Browse files
authored
Merge pull request #43 from scaleapi/sasha/custom_indexing
Support Custom Index
2 parents 2e64ccf + 81e9fbb commit a1e0b69

File tree

5 files changed

+90
-5
lines changed

5 files changed

+90
-5
lines changed

nucleus/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@
119119
ANNOTATION_METADATA_SCHEMA_KEY,
120120
ITEM_METADATA_SCHEMA_KEY,
121121
FORCE_KEY,
122+
EMBEDDINGS_URL_KEY,
122123
)
123124
from .model import Model
124125
from .errors import (
@@ -991,6 +992,27 @@ def delete_model(self, model_id: str) -> dict:
991992
)
992993
return response
993994

995+
def create_custom_index(self, dataset_id: str, embeddings_url: str):
996+
return self._make_request(
997+
{EMBEDDINGS_URL_KEY: embeddings_url},
998+
f"indexing/{dataset_id}",
999+
requests_command=requests.post,
1000+
)
1001+
1002+
def check_index_status(self, job_id: str):
1003+
return self._make_request(
1004+
{},
1005+
f"indexing/{job_id}",
1006+
requests_command=requests.get,
1007+
)
1008+
1009+
def delete_custom_index(self, dataset_id: str):
1010+
return self._make_request(
1011+
{},
1012+
f"indexing/{dataset_id}",
1013+
requests_command=requests.delete,
1014+
)
1015+
9941016
def _make_grequest(
9951017
self,
9961018
payload: dict,

nucleus/constants.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
ANNOTATION_UPDATE_KEY = "update"
2222
DEFAULT_ANNOTATION_UPDATE_MODE = False
2323
STATUS_CODE_KEY = "status_code"
24-
SUCCESS_STATUS_CODES = [200, 201]
24+
STATUS_KEY = "status"
25+
SUCCESS_STATUS_CODES = [200, 201, 202]
2526
ERRORS_KEY = "errors"
2627
MODEL_RUN_ID_KEY = "model_run_id"
2728
MODEL_ID_KEY = "model_id"
@@ -56,3 +57,6 @@
5657
MASK_URL_KEY = "mask_url"
5758
INDEX_KEY = "index"
5859
SEGMENTATIONS_KEY = "segmentations"
60+
EMBEDDINGS_URL_KEY = "embeddings_url"
61+
JOB_ID_KEY = "job_id"
62+
MESSAGE_KEY = "message"

nucleus/dataset.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
Annotation,
55
BoxAnnotation,
66
PolygonAnnotation,
7-
SegmentationAnnotation,
87
)
98
from .constants import (
109
DATASET_NAME_KEY,
@@ -17,9 +16,6 @@
1716
ITEM_KEY,
1817
DEFAULT_ANNOTATION_UPDATE_MODE,
1918
ANNOTATIONS_KEY,
20-
BOX_TYPE,
21-
POLYGON_TYPE,
22-
SEGMENTATION_TYPE,
2319
ANNOTATION_TYPES,
2420
)
2521
from .payload_constructor import construct_model_run_creation_payload
@@ -260,3 +256,12 @@ def _format_dataset_item_response(self, response: dict) -> dict:
260256
ITEM_KEY: DatasetItem.from_json(item),
261257
ANNOTATIONS_KEY: annotation_response,
262258
}
259+
260+
def create_custom_index(self, embeddings_url: str):
261+
return self._client.create_custom_index(self.id, embeddings_url)
262+
263+
def delete_custom_index(self):
264+
return self._client.delete_custom_index(self.id)
265+
266+
def check_index_status(self, job_id: str):
267+
return self._client.check_index_status(job_id)

tests/helpers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def reference_id_from_url(url):
120120
for i in range(len(TEST_POLYGON_ANNOTATIONS))
121121
]
122122

123+
TEST_INDEX_EMBEDDINGS_FILE = "https://scale-ml.s3.amazonaws.com/tmp/sasha/pytest_embeddings_payload.json"
123124

124125
# Asserts that a box annotation instance matches a dict representing its properties.
125126
# Useful to check annotation uploads/updates match.

tests/test_indexing.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import pytest
2+
3+
from helpers import (
4+
get_signed_url,
5+
TEST_INDEX_EMBEDDINGS_FILE,
6+
TEST_IMG_URLS,
7+
TEST_DATASET_NAME,
8+
reference_id_from_url,
9+
)
10+
11+
from nucleus import DatasetItem
12+
13+
from nucleus.constants import (
14+
ERROR_PAYLOAD,
15+
JOB_ID_KEY,
16+
MESSAGE_KEY,
17+
STATUS_KEY,
18+
)
19+
20+
21+
@pytest.fixture()
22+
def dataset(CLIENT):
23+
ds = CLIENT.create_dataset(TEST_DATASET_NAME)
24+
ds_items = []
25+
for url in TEST_IMG_URLS:
26+
ds_items.append(
27+
DatasetItem(
28+
image_location=url,
29+
reference_id=reference_id_from_url(url),
30+
)
31+
)
32+
33+
response = ds.append(ds_items)
34+
assert ERROR_PAYLOAD not in response.json()
35+
yield ds
36+
37+
response = CLIENT.delete_dataset(ds.id)
38+
assert response == {}
39+
40+
41+
def test_index_integration(dataset):
42+
signed_embeddings_url = get_signed_url(TEST_INDEX_EMBEDDINGS_FILE)
43+
create_response = dataset.create_custom_index(signed_embeddings_url)
44+
assert JOB_ID_KEY in create_response
45+
assert MESSAGE_KEY in create_response
46+
job_id = create_response[JOB_ID_KEY]
47+
48+
# Job can error because pytest dataset fixture gets deleted
49+
# As a workaround, we'll just check that we got some response
50+
job_status_response = dataset.check_index_status(job_id)
51+
assert STATUS_KEY in job_status_response
52+
assert JOB_ID_KEY in job_status_response
53+
assert MESSAGE_KEY in job_status_response

0 commit comments

Comments
 (0)