Skip to content

Commit 906060f

Browse files
authored
Merge pull request #82 from scaleapi/da-get-model
Get model + better error messages for segmentation upload
2 parents c25d271 + 1d6e1fe commit 906060f

File tree

7 files changed

+46
-16
lines changed

7 files changed

+46
-16
lines changed

nucleus/__init__.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -181,11 +181,11 @@ def list_models(self) -> List[Model]:
181181

182182
return [
183183
Model(
184-
model["id"],
185-
model["name"],
186-
model["ref_id"],
187-
model["metadata"],
188-
self,
184+
model_id=model["id"],
185+
name=model["name"],
186+
reference_id=model["ref_id"],
187+
metadata=model["metadata"] or None,
188+
client=self,
189189
)
190190
for model in model_objects["models"]
191191
]
@@ -231,6 +231,19 @@ def get_dataset(self, dataset_id: str) -> Dataset:
231231
"""
232232
return Dataset(dataset_id, self)
233233

234+
def get_model(self, model_id: str) -> Model:
235+
"""
236+
Fetched a model for a given id
237+
:param model_id: internally controlled dataset_id
238+
:return: model
239+
"""
240+
payload = self.make_request(
241+
payload={},
242+
route=f"model/{model_id}",
243+
requests_command=requests.get,
244+
)
245+
return Model.from_json(payload=payload, client=self)
246+
234247
def get_model_run(self, model_run_id: str, dataset_id: str) -> ModelRun:
235248
"""
236249
Fetches a model_run for given id

nucleus/annotation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,12 +309,12 @@ def to_payload(self) -> dict:
309309
}
310310

311311

312-
def check_all_annotation_paths_remote(
312+
def check_all_mask_paths_remote(
313313
annotations: Sequence[Union[Annotation]],
314314
):
315315
for annotation in annotations:
316316
if hasattr(annotation, MASK_URL_KEY):
317317
if is_local_path(getattr(annotation, MASK_URL_KEY)):
318318
raise ValueError(
319-
f"Found an annotation with a local path, which cannot be uploaded asynchronously. Use a remote path instead. {annotation}"
319+
f"Found an annotation with a local path, which is not currently supported. Use a remote path instead. {annotation}"
320320
)

nucleus/dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .annotation import (
1414
Annotation,
1515
CuboidAnnotation,
16-
check_all_annotation_paths_remote,
16+
check_all_mask_paths_remote,
1717
)
1818
from .constants import (
1919
DATASET_ITEM_IDS_KEY,
@@ -171,9 +171,9 @@ def annotate(
171171
if any((isinstance(ann, CuboidAnnotation) for ann in annotations)):
172172
raise NotImplementedError("Cuboid annotations not yet supported")
173173

174-
if asynchronous:
175-
check_all_annotation_paths_remote(annotations)
174+
check_all_mask_paths_remote(annotations)
176175

176+
if asynchronous:
177177
request_id = serialize_and_write_to_presigned_url(
178178
annotations, self.id, self._client
179179
)

nucleus/model.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,26 @@ def __repr__(self):
3232
return f"Model(model_id='{self.id}', name='{self.name}', reference_id='{self.reference_id}', metadata={self.metadata}, client={self._client})"
3333

3434
def __eq__(self, other):
35-
return self.id == other.id
35+
return (
36+
(self.id == other.id)
37+
and (self.name == other.name)
38+
and (self.metadata == other.metadata)
39+
and (self._client == other._client)
40+
)
3641

3742
def __hash__(self):
3843
return hash(self.id)
3944

45+
@classmethod
46+
def from_json(cls, payload: dict, client):
47+
return cls(
48+
model_id=payload["id"],
49+
name=payload["name"],
50+
reference_id=payload["ref_id"],
51+
metadata=payload["metadata"] or None,
52+
client=client,
53+
)
54+
4055
def create_run(
4156
self,
4257
name: str,

nucleus/model_run.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from typing import Dict, List, Optional, Type, Union
2-
32
import requests
4-
5-
from nucleus.annotation import check_all_annotation_paths_remote
3+
from nucleus.annotation import check_all_mask_paths_remote
64
from nucleus.job import AsyncJob
75
from nucleus.utils import serialize_and_write_to_presigned_url
86

@@ -108,7 +106,7 @@ def predict(
108106
}
109107
"""
110108
if asynchronous:
111-
check_all_annotation_paths_remote(annotations)
109+
check_all_mask_paths_remote(annotations)
112110

113111
request_id = serialize_and_write_to_presigned_url(
114112
annotations, self._dataset_id, self._client

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ exclude = '''
2121

2222
[tool.poetry]
2323
name = "scale-nucleus"
24-
version = "0.1.12"
24+
version = "0.1.13"
2525
description = "The official Python client library for Nucleus, the Data Platform for AI"
2626
license = "MIT"
2727
authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]

tests/test_models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ def test_model_creation_and_listing(CLIENT, dataset):
6464
# List the models
6565
ms = CLIENT.list_models()
6666

67+
# Get a model
68+
m = CLIENT.get_model(model.id)
69+
assert m == model
70+
6771
assert model in ms
6872
assert list(set(ms) - set(models_before))[0] == model
6973

0 commit comments

Comments
 (0)