Skip to content

Commit 91dad3c

Browse files
author
Diego Ardila
committed
Tests pass locally
1 parent a1f87ef commit 91dad3c

File tree

6 files changed

+52
-9
lines changed

6 files changed

+52
-9
lines changed

nucleus/__init__.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -181,11 +181,13 @@ def list_models(self) -> List[Model]:
181181

182182
return [
183183
Model(
184-
model["id"],
185-
model["name"],
186-
model["ref_id"],
187-
model["metadata"],
188-
self,
184+
model_id=model["id"],
185+
name=model["name"],
186+
reference_id=model["ref_id"],
187+
metadata=None
188+
if model["metadata"] == {}
189+
else model["metadata"],
190+
client=self,
189191
)
190192
for model in model_objects["models"]
191193
]
@@ -231,6 +233,19 @@ def get_dataset(self, dataset_id: str) -> Dataset:
231233
"""
232234
return Dataset(dataset_id, self)
233235

236+
def get_model(self, model_id: str) -> Model:
237+
"""
238+
Fetched a model for a given id
239+
:param model_id: internally controlled dataset_id
240+
:return: model
241+
"""
242+
payload = self.make_request(
243+
payload={},
244+
route=f"model/{model_id}",
245+
requests_command=requests.get,
246+
)
247+
return Model.from_json(payload=payload, client=self)
248+
234249
def get_model_run(self, model_run_id: str, dataset_id: str) -> ModelRun:
235250
"""
236251
Fetches a model_run for given id

nucleus/annotation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,5 +314,5 @@ def check_all_annotation_paths_remote(
314314
if hasattr(annotation, MASK_URL_KEY):
315315
if is_local_path(getattr(annotation, MASK_URL_KEY)):
316316
raise ValueError(
317-
f"Found an annotation with a local path, which cannot be uploaded asynchronously. Use a remote path instead. {annotation}"
317+
f"Found an annotation with a local path, which is not currently supported. Use a remote path instead. {annotation}"
318318
)

nucleus/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,9 @@ def annotate(
171171
if any((isinstance(ann, CuboidAnnotation) for ann in annotations)):
172172
raise NotImplementedError("Cuboid annotations not yet supported")
173173

174-
if asynchronous:
175-
check_all_annotation_paths_remote(annotations)
174+
check_all_annotation_paths_remote(annotations)
176175

176+
if asynchronous:
177177
request_id = serialize_and_write_to_presigned_url(
178178
annotations, self.id, self._client
179179
)

nucleus/model.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,28 @@ def __repr__(self):
3232
return f"Model(model_id='{self.id}', name='{self.name}', reference_id='{self.reference_id}', metadata={self.metadata}, client={self._client})"
3333

3434
def __eq__(self, other):
35-
return self.id == other.id
35+
return (
36+
(self.id == other.id)
37+
and (self.name == other.name)
38+
and (self.metadata == other.metadata)
39+
and (self._client == other._client)
40+
)
3641

3742
def __hash__(self):
3843
return hash(self.id)
3944

45+
@classmethod
46+
def from_json(cls, payload: dict, client):
47+
return cls(
48+
model_id=payload["id"],
49+
name=payload["name"],
50+
reference_id=payload["ref_id"],
51+
metadata=None
52+
if payload["metadata"] == {}
53+
else payload["metadata"],
54+
client=client,
55+
)
56+
4057
def create_run(
4158
self,
4259
name: str,

tests/test_dataset.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,13 @@ def dataset(CLIENT):
7676
assert response == {"message": "Beginning dataset deletion..."}
7777

7878

79+
def test_upload_nonsense(dataset):
80+
response = dataset.append(
81+
[DatasetItem(image_location="https://fake.com/image.jpeg")]
82+
)
83+
print(response)
84+
85+
7986
def make_dataset_items():
8087
ds_items_with_metadata = []
8188
for i, url in enumerate(TEST_IMG_URLS):

tests/test_models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ def test_model_creation_and_listing(CLIENT, dataset):
6464
# List the models
6565
ms = CLIENT.list_models()
6666

67+
# Get a model
68+
m = CLIENT.get_model(model.id)
69+
assert m == model
70+
6771
assert model in ms
6872
assert list(set(ms) - set(models_before))[0] == model
6973

0 commit comments

Comments
 (0)