Skip to content

Commit b928eaf

Browse files
Merge pull request #34 from scaleapi/list-models-endpoint
added ability to list models
2 parents da7e7f8 + 2d8cb83 commit b928eaf

File tree

6 files changed

+86
-5
lines changed

6 files changed

+86
-5
lines changed

conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import nucleus
1818
from nucleus.constants import SUCCESS_STATUS_CODES
1919

20+
from tests.helpers import TEST_DATASET_NAME, TEST_DATASET_ITEMS
2021

2122
assert 'NUCLEUS_PYTEST_API_KEY' in os.environ, \
2223
"You must set the 'NUCLEUS_PYTEST_API_KEY' environment variable to a valid " \
@@ -55,3 +56,11 @@ def _make_request_patch(
5556

5657
monkeypatch_session.setattr(client, "_make_request", _make_request_patch)
5758
return client
59+
60+
@pytest.fixture()
61+
def dataset(CLIENT):
62+
ds = CLIENT.create_dataset(TEST_DATASET_NAME)
63+
ds.append(TEST_DATASET_ITEMS)
64+
yield ds
65+
66+
CLIENT.delete_dataset(ds.id)

nucleus/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,13 +133,14 @@ def __init__(self, api_key: str, use_notebook: bool = False):
133133
if use_notebook:
134134
self.tqdm_bar = tqdm_notebook.tqdm
135135

136-
def list_models(self) -> List[str]:
136+
def list_models(self) -> List[Model]:
137137
"""
138138
Lists available models in your repo.
139139
:return: model_ids
140140
"""
141-
# TODO implement API
142-
raise NotImplementedError
141+
model_objects = self._make_request({}, "models/", requests.get)
142+
143+
return [Model(model["id"], model["name"], model["ref_id"], model["metadata"], self) for model in model_objects["models"]]
143144

144145
def list_datasets(self) -> Dict[str, Union[str, List[str]]]:
145146
"""

nucleus/model.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,15 @@ def __init__(
2424
self.metadata = metadata
2525
self._client = client
2626

27+
def __repr__(self):
28+
return f'Model(model_id={self.id}, name={self.name}, reference_id={self.reference_id}, metadata={self.metadata}, client={self._client})'
29+
30+
def __eq__(self, other):
31+
return self.id == other.id
32+
33+
def __hash__(self):
34+
return hash(self.id)
35+
2736
def create_run(
2837
self,
2938
name: str,

tests/__init__.py

Whitespace-only changes.

tests/helpers.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from pathlib import Path
2+
from nucleus import DatasetItem, BoxPrediction
23

3-
TEST_MODEL_NAME = '[PyTest] Test Model'
4+
TEST_MODEL_NAME = '[PyTest] Test Model Name'
45
TEST_MODEL_REFERENCE = '[PyTest] Test Model Reference'
5-
TEST_MODEL_RUN = '[PyTest] Test Model Run'
6+
TEST_MODEL_RUN = '[PyTest] Test Model Run Reference'
67
TEST_DATASET_NAME = '[PyTest] Test Dataset'
78
TEST_SLICE_NAME = '[PyTest] Test Slice'
89
TEST_IMG_URLS = [
@@ -12,6 +13,18 @@
1213
's3://scaleapi-attachments/BDD/BDD/bdd100k/images/100k/train/06924f46-1708b96f.jpg',
1314
's3://scaleapi-attachments/BDD/BDD/bdd100k/images/100k/train/89b42832-10d662f4.jpg',
1415
]
16+
TEST_DATASET_ITEMS = [
17+
DatasetItem(TEST_IMG_URLS[0], '1'),
18+
DatasetItem(TEST_IMG_URLS[1], '2'),
19+
DatasetItem(TEST_IMG_URLS[2], '3'),
20+
DatasetItem(TEST_IMG_URLS[3], '4')
21+
]
22+
TEST_PREDS = [
23+
BoxPrediction('[Pytest Box Prediction 1]', 0, 0, 100, 100, '1'),
24+
BoxPrediction('[Pytest Box Prediction 2]', 0, 0, 100, 100, '2'),
25+
BoxPrediction('[Pytest Box Prediction 3]', 0, 0, 100, 100, '3'),
26+
BoxPrediction('[Pytest Box Prediction 4]', 0, 0, 100, 100, '4')
27+
]
1528

1629
def reference_id_from_url(url):
1730
return Path(url).name

tests/test_models.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from pathlib import Path
2+
import pytest
3+
from nucleus import (
4+
Dataset,
5+
DatasetItem,
6+
UploadResponse,
7+
Model,
8+
ModelRun,
9+
BoxPrediction
10+
)
11+
from nucleus.constants import (
12+
NEW_ITEMS,
13+
UPDATED_ITEMS,
14+
IGNORED_ITEMS,
15+
ERROR_ITEMS,
16+
ERROR_PAYLOAD,
17+
DATASET_ID_KEY,
18+
)
19+
from .helpers import (
20+
TEST_MODEL_NAME,
21+
TEST_MODEL_REFERENCE,
22+
TEST_MODEL_RUN,
23+
TEST_PREDS
24+
)
25+
26+
def test_model_creation_and_listing(CLIENT, dataset):
27+
models_before = CLIENT.list_models()
28+
29+
# Creation
30+
model = CLIENT.add_model(TEST_MODEL_NAME, TEST_MODEL_REFERENCE)
31+
m_run = model.create_run(TEST_MODEL_RUN, dataset, TEST_PREDS)
32+
m_run.commit()
33+
34+
assert isinstance(model, Model)
35+
assert isinstance(m_run, ModelRun)
36+
37+
# List the models
38+
ms = CLIENT.list_models()
39+
40+
assert model in ms
41+
assert list(set(ms) - set(models_before))[0] == model
42+
43+
# Delete the model
44+
CLIENT.delete_model(model.id)
45+
ms = CLIENT.list_models()
46+
47+
assert model not in ms
48+
assert ms == models_before
49+

0 commit comments

Comments
 (0)