Skip to content

Commit c5f7bf2

Browse files
author
Bihan Jiang
committed
fix failing tests
1 parent bf3592e commit c5f7bf2

File tree

6 files changed

+18
-18
lines changed

6 files changed

+18
-18
lines changed

nucleus/__init__.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,19 +1163,14 @@ def create_custom_index(
11631163
embeddings_urls: list of urls, each of which being a json mapping dataset_item_id -> embedding vector
11641164
embedding_dim: the dimension of the embedding vectors, must be consistent for all embedding vectors in the index.
11651165
"""
1166-
response_objects = self.make_request(
1166+
return self.make_request(
11671167
{
11681168
EMBEDDINGS_URL_KEY: embeddings_urls,
11691169
EMBEDDING_DIMENSION_KEY: embedding_dim,
11701170
},
11711171
f"indexing/{dataset_id}",
11721172
requests_command=requests.post,
11731173
)
1174-
job = AsyncJob.from_json(response_objects, self)
1175-
dataset_id = response_objects[DATASET_ID_KEY]
1176-
message = response_objects[MESSAGE_KEY]
1177-
1178-
return dataset_id, job, message
11791174

11801175
def check_index_status(self, job_id: str):
11811176
return self.make_request(

nucleus/job.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from nucleus.constants import (
66
JOB_CREATION_TIME_KEY,
77
JOB_ID_KEY,
8-
JOB_STATUS_KEY,
98
JOB_LAST_KNOWN_STATUS_KEY,
109
JOB_TYPE_KEY,
10+
STATUS_KEY,
1111
)
1212

1313
JOB_POLLING_INTERVAL = 5
@@ -27,7 +27,7 @@ def status(self) -> Dict[str, str]:
2727
route=f"job/{self.job_id}",
2828
requests_command=requests.get,
2929
)
30-
self.job_last_known_status = response[JOB_STATUS_KEY]
30+
self.job_last_known_status = response[STATUS_KEY]
3131
return response
3232

3333
def errors(self) -> List[str]:

tests/test_dataset.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,12 +169,12 @@ def test_dataset_append_async(dataset: Dataset):
169169
status = job.status()
170170
status["message"]["PayloadUrl"] = ""
171171
assert status == {
172-
"job_id": job.id,
172+
"job_id": job.job_id,
173173
"status": "Completed",
174174
"message": {
175175
"PayloadUrl": "",
176176
"image_upload_step": {"errored": 0, "pending": 0, "completed": 5},
177-
"started_image_processing": f"Dataset: {dataset.id}, Job: {job.id}",
177+
"started_image_processing": f"Dataset: {dataset.id}, Job: {job.job_id}",
178178
"ingest_to_reupload_queue": {
179179
"epoch": 1,
180180
"total": 5,
@@ -204,7 +204,7 @@ def test_dataset_append_async_with_1_bad_url(dataset: Dataset):
204204
status = job.status()
205205
status["message"]["PayloadUrl"] = ""
206206
assert status == {
207-
"job_id": f"{job.id}",
207+
"job_id": f"{job.job_id}",
208208
"status": "Errored",
209209
"message": {
210210
"PayloadUrl": "",
@@ -220,7 +220,7 @@ def test_dataset_append_async_with_1_bad_url(dataset: Dataset):
220220
"datasetId": f"{dataset.id}",
221221
"processed": 5,
222222
},
223-
"started_image_processing": f"Dataset: {dataset.id}, Job: {job.id}",
223+
"started_image_processing": f"Dataset: {dataset.id}, Job: {job.job_id}",
224224
},
225225
}
226226
# The error is fairly detailed and subject to change. What's important is we surface which URLs failed.
@@ -286,7 +286,7 @@ def test_annotate_async(dataset: Dataset):
286286
)
287287
job.sleep_until_complete()
288288
assert job.status() == {
289-
"job_id": job.id,
289+
"job_id": job.job_id,
290290
"status": "Completed",
291291
"message": {
292292
"annotation_upload": {
@@ -321,7 +321,7 @@ def test_annotate_async_with_error(dataset: Dataset):
321321
job.sleep_until_complete()
322322

323323
assert job.status() == {
324-
"job_id": job.id,
324+
"job_id": job.job_id,
325325
"status": "Completed",
326326
"message": {
327327
"annotation_upload": {

tests/test_indexing.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from nucleus.job import AsyncJob
12
import pytest
23

34
from .helpers import (
@@ -42,7 +43,11 @@ def test_index_integration(dataset):
4243
create_response = dataset.create_custom_index(
4344
[signed_embeddings_url], embedding_dim=3
4445
)
45-
assert JOB_ID_KEY in create_response
46+
job = AsyncJob.from_json(create_response, client="Nucleus Client")
47+
assert job.job_id
48+
assert job.job_last_known_status
49+
assert job.job_type
50+
assert job.job_creation_time
4651
assert MESSAGE_KEY in create_response
4752
job_id = create_response[JOB_ID_KEY]
4853

tests/test_jobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,4 @@ def test_job_creation_and_listing(CLIENT):
2828
jobs = CLIENT.list_jobs()
2929

3030
for job in jobs:
31-
assert eval(print(job)) == job
31+
assert eval(str(job)) == job

tests/test_prediction.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def test_mixed_pred_upload_async(model_run: ModelRun):
290290
job.sleep_until_complete()
291291

292292
assert job.status() == {
293-
"job_id": job.id,
293+
"job_id": job.job_id,
294294
"status": "Completed",
295295
"message": {
296296
"prediction_upload": {
@@ -328,7 +328,7 @@ def test_mixed_pred_upload_async_with_error(model_run: ModelRun):
328328
job.sleep_until_complete()
329329

330330
assert job.status() == {
331-
"job_id": job.id,
331+
"job_id": job.job_id,
332332
"status": "Completed",
333333
"message": {
334334
"prediction_upload": {

0 commit comments

Comments
 (0)