Skip to content

Commit 1c74466

Browse files
committed
lint and up version
1 parent b2bfa5d commit 1c74466

File tree

7 files changed

+112
-1
lines changed

7 files changed

+112
-1
lines changed

nucleus/__init__.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,40 @@ def predictions_loc(self, model_run_id: str, dataset_item_id: str):
951951
{}, f"modelRun/{model_run_id}/loc/{dataset_item_id}", requests.get
952952
)
953953

954+
def prediction_loc(self, model_run_id: str, prediction_id: str):
955+
"""
956+
Returns info for single Prediction by its id.
957+
:param model_run_id: id of the model run
958+
:param prediction_id: internally controlled id for model predictions.
959+
:return:
960+
{
961+
"ref_id": Reference id of dataset item associated with this prediction
962+
"prediction": BoxPrediction | PolygonPrediction | CuboidPrediction
963+
}
964+
"""
965+
return self.make_request(
966+
{},
967+
f"modelRun/{model_run_id}/prediction/loc/{prediction_id}",
968+
requests.get,
969+
)
970+
971+
def ground_truth_loc(self, dataset_id: str, ground_truth_id: str):
972+
"""
973+
Returns info for single Prediction by its id.
974+
:param dataset_id: id of the dataset
975+
:param ground_truth_id: internally controlled id for ground truth annotations.
976+
:return:
977+
{
978+
"ref_id": Reference id of dataset item associated with this prediction
979+
"ground_truth": BoxAnnotation | PolygonAnnotation | CuboidAnnotation
980+
}
981+
"""
982+
return self.make_request(
983+
{},
984+
f"dataset/{dataset_id}/groundTruth/loc/{ground_truth_id}",
985+
requests.get,
986+
)
987+
954988
def create_slice(self, dataset_id: str, payload: dict) -> Slice:
955989
"""
956990
Creates a slice from items already present in a dataset.

nucleus/dataset.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
DEFAULT_ANNOTATION_UPDATE_MODE,
2323
EXPORTED_ROWS,
2424
NAME_KEY,
25+
REFERENCE_ID_KEY,
2526
REFERENCE_IDS_KEY,
2627
REQUEST_ID_KEY,
2728
AUTOTAG_SCORE_THRESHOLD,
@@ -385,6 +386,28 @@ def loc(self, dataset_item_id: str) -> dict:
385386
response = self._client.dataitem_loc(self.id, dataset_item_id)
386387
return format_dataset_item_response(response)
387388

389+
def ground_truth_loc(self, ground_truth_id: str):
390+
"""
391+
Returns info for single ground truth Annotation by its id.
392+
:param prediction_id: internally controlled id for ground truth annotations.
393+
:return:
394+
{
395+
"ref_id": Reference id of dataset item associated with this annotation
396+
"groundTruth": BoxAnnotation | PolygonAnnotation | CuboidAnnotation
397+
}
398+
"""
399+
response = self._client.ground_truth_loc(self.id, ground_truth_id)
400+
annotation_type = response["groundTruth"].keys()[0]
401+
return {
402+
"ref_id": response["ref_id"],
403+
"ground_truth": Annotation.from_json(
404+
{
405+
REFERENCE_ID_KEY: response["ref_id"],
406+
**response["groundTruth"][annotation_type],
407+
}
408+
),
409+
}
410+
388411
def create_slice(
389412
self,
390413
name: str,

nucleus/model_run.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,26 @@ def loc(self, dataset_item_id: str):
160160
)
161161
return self._format_prediction_response(response)
162162

163+
def prediction_loc(self, prediction_id: str):
164+
"""
165+
Returns info for single Prediction by its id.
166+
:param prediction_id: internally controlled id for model predictions.
167+
:return:
168+
{
169+
"ref_id": Reference id of dataset item associated with this prediction
170+
"prediction": BoxPrediction | PolygonPrediction | CuboidPrediction
171+
}
172+
"""
173+
response = self._client.prediction_loc(
174+
self.model_run_id, prediction_id
175+
)
176+
return {
177+
"ref_id": response["ref_id"],
178+
"prediction": self._format_prediction_response(
179+
{"predictions": [response["prediction"]]}
180+
),
181+
}
182+
163183
def ungrouped_export(self):
164184
json_response = self._client.make_request(
165185
payload={},

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ exclude = '''
2121

2222
[tool.poetry]
2323
name = "scale-nucleus"
24-
version = "0.1.21"
24+
version = "0.1.22"
2525
description = "The official Python client library for Nucleus, the Data Platform for AI"
2626
license = "MIT"
2727
authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]

tests/helpers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
TEST_PROJECT_ID = "60b699d70f139e002dd31bfc"
1515

1616
DATASET_WITH_AUTOTAG = "ds_c4dgj702e2vjft7m9xa0"
17+
DATASET_WITH_PREDICTIONS = "ds_c4283fxxjfvg05rmm5p0"
18+
DATASET_WITH_PREDICTIONS_MODEL_RUN_ID = "run_c4283gpxjfvg0ehz6gjg"
19+
DATASET_WITH_PREDICTIONS_SAMPLE_PREDICTION_ID = "pred_c4283gp2s2b00esdarq0"
20+
DATASET_WITH_GROUND_TRUTH = "ds_c3pz51mjy60g09hfzhfg"
21+
DATASET_WITH_GROUND_TRUTH_SAMPLE_GROUND_TRUTH_ID = "ann_c3pz524k221g0css4ns0"
1722
NUCLEUS_PYTEST_USER_ID = "60ad648c85db770026e9bf77"
1823

1924

tests/test_dataset.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
from nucleus.job import AsyncJob, JobError
3434

3535
from .helpers import (
36+
DATASET_WITH_GROUND_TRUTH,
37+
DATASET_WITH_GROUND_TRUTH_SAMPLE_GROUND_TRUTH_ID,
3638
LOCAL_FILENAME,
3739
TEST_BOX_ANNOTATIONS,
3840
TEST_DATASET_NAME,
@@ -100,6 +102,15 @@ def make_dataset_items():
100102
return ds_items_with_metadata
101103

102104

105+
def test_dataset_ground_truth_loc(CLIENT):
106+
ds = CLIENT.get_dataset(DATASET_WITH_GROUND_TRUTH)
107+
response = ds.ground_truth_loc(
108+
DATASET_WITH_GROUND_TRUTH_SAMPLE_GROUND_TRUTH_ID
109+
)
110+
assert response["ref_id"] is not None
111+
assert response["ground_truth"] is not None
112+
113+
103114
def test_dataset_create_and_delete(CLIENT):
104115
# Creation
105116
ds = CLIENT.create_dataset(TEST_DATASET_NAME)

tests/test_prediction.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
from nucleus.job import AsyncJob
2+
import os
23
import pytest
34
import time
45
from .helpers import (
6+
DATASET_WITH_PREDICTIONS,
7+
DATASET_WITH_PREDICTIONS_MODEL_RUN_ID,
8+
DATASET_WITH_PREDICTIONS_SAMPLE_PREDICTION_ID,
9+
NUCLEUS_PYTEST_USER_ID,
510
TEST_DATASET_NAME,
611
TEST_MODEL_NAME,
712
TEST_MODEL_RUN,
@@ -76,6 +81,19 @@ def model_run(CLIENT):
7681
assert response == {}
7782

7883

84+
def test_pred_loc(CLIENT):
85+
if NUCLEUS_PYTEST_USER_ID in os.environ["NUCLEUS_PYTEST_USER_ID"]:
86+
run = CLIENT.get_model_run(
87+
DATASET_WITH_PREDICTIONS_MODEL_RUN_ID, DATASET_WITH_PREDICTIONS
88+
)
89+
response = run.prediction_loc(
90+
DATASET_WITH_PREDICTIONS_SAMPLE_PREDICTION_ID
91+
)
92+
93+
assert response["ref_id"] is not None
94+
assert response["prediction"] is not None
95+
96+
7997
def test_box_pred_upload(model_run):
8098
prediction = BoxPrediction(**TEST_BOX_PREDICTIONS[0])
8199
response = model_run.predict(annotations=[prediction])

0 commit comments

Comments
 (0)