Skip to content

Commit 759af4d

Browse files
authored
Merge branch 'master' into support_cuboid_annotations
2 parents 31a802d + 906060f commit 759af4d

File tree

8 files changed

+67
-22
lines changed

8 files changed

+67
-22
lines changed

nucleus/__init__.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -181,11 +181,11 @@ def list_models(self) -> List[Model]:
181181

182182
return [
183183
Model(
184-
model["id"],
185-
model["name"],
186-
model["ref_id"],
187-
model["metadata"],
188-
self,
184+
model_id=model["id"],
185+
name=model["name"],
186+
reference_id=model["ref_id"],
187+
metadata=model["metadata"] or None,
188+
client=self,
189189
)
190190
for model in model_objects["models"]
191191
]
@@ -231,6 +231,19 @@ def get_dataset(self, dataset_id: str) -> Dataset:
231231
"""
232232
return Dataset(dataset_id, self)
233233

234+
def get_model(self, model_id: str) -> Model:
235+
"""
236+
Fetched a model for a given id
237+
:param model_id: internally controlled dataset_id
238+
:return: model
239+
"""
240+
payload = self.make_request(
241+
payload={},
242+
route=f"model/{model_id}",
243+
requests_command=requests.get,
244+
)
245+
return Model.from_json(payload=payload, client=self)
246+
234247
def get_model_run(self, model_run_id: str, dataset_id: str) -> ModelRun:
235248
"""
236249
Fetches a model_run for given id

nucleus/annotation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ class Annotation:
3737
item_id: Optional[str] = None
3838

3939
def _check_ids(self):
40-
if bool(self.reference_id) == bool(self.item_id):
40+
if self.reference_id and self.item_id:
41+
self.item_id = None # Prefer reference id to item id.
42+
if not (self.reference_id or self.item_id):
4143
raise Exception(
4244
"You must specify either a reference_id or an item_id for an annotation."
4345
)
@@ -307,12 +309,12 @@ def to_payload(self) -> dict:
307309
}
308310

309311

310-
def check_all_annotation_paths_remote(
312+
def check_all_mask_paths_remote(
311313
annotations: Sequence[Union[Annotation]],
312314
):
313315
for annotation in annotations:
314316
if hasattr(annotation, MASK_URL_KEY):
315317
if is_local_path(getattr(annotation, MASK_URL_KEY)):
316318
raise ValueError(
317-
f"Found an annotation with a local path, which cannot be uploaded asynchronously. Use a remote path instead. {annotation}"
319+
f"Found an annotation with a local path, which is not currently supported. Use a remote path instead. {annotation}"
318320
)

nucleus/dataset.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212

1313
from .annotation import (
1414
Annotation,
15-
check_all_annotation_paths_remote,
15+
CuboidAnnotation,
16+
check_all_mask_paths_remote,
1617
)
1718
from .constants import (
1819
DATASET_ITEM_IDS_KEY,
@@ -167,9 +168,9 @@ def annotate(
167168
"ignored_items": int,
168169
}
169170
"""
170-
if asynchronous:
171-
check_all_annotation_paths_remote(annotations)
171+
check_all_mask_paths_remote(annotations)
172172

173+
if asynchronous:
173174
request_id = serialize_and_write_to_presigned_url(
174175
annotations, self.id, self._client
175176
)

nucleus/model.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,26 @@ def __repr__(self):
3333
return f"Model(model_id='{self.id}', name='{self.name}', reference_id='{self.reference_id}', metadata={self.metadata}, client={self._client})"
3434

3535
def __eq__(self, other):
36-
return self.id == other.id
36+
return (
37+
(self.id == other.id)
38+
and (self.name == other.name)
39+
and (self.metadata == other.metadata)
40+
and (self._client == other._client)
41+
)
3742

3843
def __hash__(self):
3944
return hash(self.id)
4045

46+
@classmethod
47+
def from_json(cls, payload: dict, client):
48+
return cls(
49+
model_id=payload["id"],
50+
name=payload["name"],
51+
reference_id=payload["ref_id"],
52+
metadata=payload["metadata"] or None,
53+
client=client,
54+
)
55+
4156
def create_run(
4257
self,
4358
name: str,

nucleus/model_run.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Dict, List, Optional, Type, Union
2-
3-
from nucleus.annotation import check_all_annotation_paths_remote
2+
import requests
3+
from nucleus.annotation import check_all_mask_paths_remote
44
from nucleus.job import AsyncJob
55
from nucleus.utils import serialize_and_write_to_presigned_url
66

@@ -113,7 +113,7 @@ def predict(
113113
}
114114
"""
115115
if asynchronous:
116-
check_all_annotation_paths_remote(annotations)
116+
check_all_mask_paths_remote(annotations)
117117

118118
request_id = serialize_and_write_to_presigned_url(
119119
annotations, self._dataset_id, self._client
@@ -162,6 +162,16 @@ def loc(self, dataset_item_id: str):
162162
)
163163
return self._format_prediction_response(response)
164164

165+
def ungrouped_export(self):
166+
json_response = self._client.make_request(
167+
payload={},
168+
route=f"modelRun/{self.model_run_id}/ungrouped",
169+
requests_command=requests.get,
170+
)
171+
return self._format_prediction_response(
172+
{ANNOTATIONS_KEY: json_response}
173+
)
174+
165175
def _format_prediction_response(
166176
self, response: dict
167177
) -> Union[

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ exclude = '''
2121

2222
[tool.poetry]
2323
name = "scale-nucleus"
24-
version = "0.1.11"
24+
version = "0.1.13"
2525
description = "The official Python client library for Nucleus, the Data Platform for AI"
2626
license = "MIT"
2727
authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]

tests/test_models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ def test_model_creation_and_listing(CLIENT, dataset):
6464
# List the models
6565
ms = CLIENT.list_models()
6666

67+
# Get a model
68+
m = CLIENT.get_model(model.id)
69+
assert m == model
70+
6771
assert model in ms
6872
assert list(set(ms) - set(models_before))[0] == model
6973

tests/test_prediction.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_repr(test_object: any):
5050
def model_run(CLIENT):
5151
ds = CLIENT.create_dataset(TEST_DATASET_NAME)
5252
ds_items = []
53-
for url in TEST_IMG_URLS:
53+
for url in TEST_IMG_URLS[:2]:
5454
ds_items.append(
5555
DatasetItem(
5656
image_location=url,
@@ -246,7 +246,7 @@ def test_polygon_pred_upload_ignore(model_run):
246246
)
247247

248248

249-
def test_mixed_pred_upload(model_run):
249+
def test_mixed_pred_upload(model_run: ModelRun):
250250
prediction_semseg = SegmentationPrediction.from_json(
251251
TEST_SEGMENTATION_PREDICTIONS[0]
252252
)
@@ -262,15 +262,15 @@ def test_mixed_pred_upload(model_run):
262262
assert response["predictions_processed"] == 3
263263
assert response["predictions_ignored"] == 0
264264

265-
response_refloc = model_run.refloc(prediction_polygon.reference_id)
265+
all_predictions = model_run.ungrouped_export()
266266
assert_box_prediction_matches_dict(
267-
response_refloc["box"][0], TEST_BOX_PREDICTIONS[0]
267+
all_predictions["box"][0], TEST_BOX_PREDICTIONS[0]
268268
)
269269
assert_polygon_prediction_matches_dict(
270-
response_refloc["polygon"][0], TEST_POLYGON_PREDICTIONS[0]
270+
all_predictions["polygon"][0], TEST_POLYGON_PREDICTIONS[0]
271271
)
272272
assert_segmentation_annotation_matches_dict(
273-
response_refloc["segmentation"][0], TEST_SEGMENTATION_PREDICTIONS[0]
273+
all_predictions["segmentation"][0], TEST_SEGMENTATION_PREDICTIONS[0]
274274
)
275275

276276

0 commit comments

Comments
 (0)