Skip to content

Commit 255a642

Browse files
added equality operator and unit tests
1 parent c0023aa commit 255a642

File tree

3 files changed

+54
-5
lines changed

3 files changed

+54
-5
lines changed

nucleus/__init__.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,7 @@ def list_models(self) -> List[Model]:
138138
"""
139139
model_objects = self._make_request({}, "models/", requests.get)
140140

141-
try:
142-
models = [Model(model["id"], model["name"], model["ref_id"], model["metadata"], self) for model in model_objects["models"]]
143-
return models
144-
except e:
145-
return []
141+
return [Model(model["id"], model["name"], model["ref_id"], model["metadata"], self) for model in model_objects["models"]]
146142

147143
def list_datasets(self) -> Dict[str, Union[str, List[str]]]:
148144
"""

nucleus/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ def __init__(
2727
def __repr__(self):
2828
return f'Model(model_id={self.id}, name={self.name}, reference_id={self.reference_id}, metadata={self.metadata}, client={self._client})'
2929

30+
def __eq__(self, other):
31+
return self.id == other.id
32+
3033
def create_run(
3134
self,
3235
name: str,

tests/test_models.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from pathlib import Path
2+
import pytest
3+
from nucleus import Dataset, DatasetItem, UploadResponse, Model, ModelRun
4+
from nucleus.constants import (
5+
NEW_ITEMS,
6+
UPDATED_ITEMS,
7+
IGNORED_ITEMS,
8+
ERROR_ITEMS,
9+
ERROR_PAYLOAD,
10+
DATASET_ID_KEY,
11+
)
12+
13+
TEST_MODEL_NAME = '[PyTest] Test Model'
14+
TEST_REFERENCE_ID = '[PyTest] Test Model'
15+
TEST_METADATA = {
16+
'key': 'value'
17+
}
18+
TEST_MODEL_RUN_NAME = '[PyTest] Test ModelRun'
19+
TEST_DATASET_NAME = '[PyTest] Test Dataset'
20+
TEST_SLICE_NAME = '[PyTest] Test Slice'
21+
TEST_IMG_URLS = [
22+
"s3://scaleapi-attachments/BDD/BDD/bdd100k/images/100k/train/6dd63871-831611a6.jpg",
23+
"s3://scaleapi-attachments/BDD/BDD/bdd100k/images/100k/train/82c1005c-e2d1d94f.jpg",
24+
"s3://scaleapi-attachments/BDD/BDD/bdd100k/images/100k/train/7f2e1814-6591087d.jpg",
25+
"s3://scaleapi-attachments/BDD/BDD/bdd100k/images/100k/train/06924f46-1708b96f.jpg",
26+
"s3://scaleapi-attachments/BDD/BDD/bdd100k/images/100k/train/89b42832-10d662f4.jpg",
27+
]
28+
29+
@pytest.fixture()
30+
def dataset(CLIENT):
31+
ds = CLIENT.create_dataset(TEST_DATASET_NAME)
32+
yield ds
33+
34+
CLIENT.delete_dataset(ds.id)
35+
36+
def test_model_creation_and_listing(CLIENT):
37+
# Creation
38+
m = CLIENT.add_model(TEST_MODEL_NAME, TEST_REFERENCE_ID, TEST_METADATA)
39+
m_run = m.create_run(TEST_MODEL_RUN_NAME, TEST_DATASET, TEST_PREDS, TEST_METADATA)
40+
41+
assert isinstance(m, Model)
42+
assert isinstance(m_run, ModelRun)
43+
44+
# List
45+
ms = CLIENT.list_models()
46+
47+
assert m in ms
48+
49+
CLIENT.delete_model(m.id)
50+
CLIENT.delete_model_run(m_run.model_run_id)

0 commit comments

Comments
 (0)