Skip to content

Commit 74e619f

Browse files
authored
Merge branch 'master' into sasha/segmentation
2 parents 369672c + b928eaf commit 74e619f

File tree

6 files changed

+90
-3
lines changed

6 files changed

+90
-3
lines changed

conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import nucleus
1818
from nucleus.constants import SUCCESS_STATUS_CODES
1919

20+
from tests.helpers import TEST_DATASET_NAME, TEST_DATASET_ITEMS
2021

2122
assert 'NUCLEUS_PYTEST_API_KEY' in os.environ, \
2223
"You must set the 'NUCLEUS_PYTEST_API_KEY' environment variable to a valid " \
@@ -55,3 +56,11 @@ def _make_request_patch(
5556

5657
monkeypatch_session.setattr(client, "_make_request", _make_request_patch)
5758
return client
59+
60+
@pytest.fixture()
61+
def dataset(CLIENT):
62+
ds = CLIENT.create_dataset(TEST_DATASET_NAME)
63+
ds.append(TEST_DATASET_ITEMS)
64+
yield ds
65+
66+
CLIENT.delete_dataset(ds.id)

nucleus/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,14 @@ def __init__(self, api_key: str, use_notebook: bool = False):
138138
if use_notebook:
139139
self.tqdm_bar = tqdm_notebook.tqdm
140140

141-
def list_models(self) -> List[str]:
141+
def list_models(self) -> List[Model]:
142142
"""
143143
Lists available models in your repo.
144144
:return: model_ids
145145
"""
146-
# TODO implement API
147-
raise NotImplementedError
146+
model_objects = self._make_request({}, "models/", requests.get)
147+
148+
return [Model(model["id"], model["name"], model["ref_id"], model["metadata"], self) for model in model_objects["models"]]
148149

149150
def list_datasets(self) -> Dict[str, Union[str, List[str]]]:
150151
"""

nucleus/model.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,15 @@ def __init__(
2424
self.metadata = metadata
2525
self._client = client
2626

27+
def __repr__(self):
28+
return f'Model(model_id={self.id}, name={self.name}, reference_id={self.reference_id}, metadata={self.metadata}, client={self._client})'
29+
30+
def __eq__(self, other):
31+
return self.id == other.id
32+
33+
def __hash__(self):
34+
return hash(self.id)
35+
2736
def create_run(
2837
self,
2938
name: str,

tests/__init__.py

Whitespace-only changes.

tests/helpers.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pathlib import Path
22
from urllib.parse import urlparse
33
import boto3
4+
from nucleus import DatasetItem, BoxPrediction
45

56
PRESIGN_EXPIRY_SECONDS = 60 * 60 * 24 * 2 # 2 days
67

@@ -9,13 +10,31 @@
910
TEST_MODEL_RUN = "[PyTest] Test Model Run"
1011
TEST_DATASET_NAME = "[PyTest] Test Dataset"
1112
TEST_SLICE_NAME = "[PyTest] Test Slice"
13+
14+
TEST_MODEL_NAME = '[PyTest] Test Model Name'
15+
TEST_MODEL_REFERENCE = '[PyTest] Test Model Reference'
16+
TEST_MODEL_RUN = '[PyTest] Test Model Run Reference'
17+
TEST_DATASET_NAME = '[PyTest] Test Dataset'
18+
TEST_SLICE_NAME = '[PyTest] Test Slice'
1219
TEST_IMG_URLS = [
1320
"s3://scaleapi-attachments/BDD/BDD/bdd100k/images/100k/train/6dd63871-831611a6.jpg",
1421
"s3://scaleapi-attachments/BDD/BDD/bdd100k/images/100k/train/82c1005c-e2d1d94f.jpg",
1522
"s3://scaleapi-attachments/BDD/BDD/bdd100k/images/100k/train/7f2e1814-6591087d.jpg",
1623
"s3://scaleapi-attachments/BDD/BDD/bdd100k/images/100k/train/06924f46-1708b96f.jpg",
1724
"s3://scaleapi-attachments/BDD/BDD/bdd100k/images/100k/train/89b42832-10d662f4.jpg",
1825
]
26+
TEST_DATASET_ITEMS = [
27+
DatasetItem(TEST_IMG_URLS[0], '1'),
28+
DatasetItem(TEST_IMG_URLS[1], '2'),
29+
DatasetItem(TEST_IMG_URLS[2], '3'),
30+
DatasetItem(TEST_IMG_URLS[3], '4')
31+
]
32+
TEST_PREDS = [
33+
BoxPrediction('[Pytest Box Prediction 1]', 0, 0, 100, 100, '1'),
34+
BoxPrediction('[Pytest Box Prediction 2]', 0, 0, 100, 100, '2'),
35+
BoxPrediction('[Pytest Box Prediction 3]', 0, 0, 100, 100, '3'),
36+
BoxPrediction('[Pytest Box Prediction 4]', 0, 0, 100, 100, '4')
37+
]
1938

2039

2140
def get_signed_url(url):

tests/test_models.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from pathlib import Path
2+
import pytest
3+
from nucleus import (
4+
Dataset,
5+
DatasetItem,
6+
UploadResponse,
7+
Model,
8+
ModelRun,
9+
BoxPrediction
10+
)
11+
from nucleus.constants import (
12+
NEW_ITEMS,
13+
UPDATED_ITEMS,
14+
IGNORED_ITEMS,
15+
ERROR_ITEMS,
16+
ERROR_PAYLOAD,
17+
DATASET_ID_KEY,
18+
)
19+
from .helpers import (
20+
TEST_MODEL_NAME,
21+
TEST_MODEL_REFERENCE,
22+
TEST_MODEL_RUN,
23+
TEST_PREDS
24+
)
25+
26+
def test_model_creation_and_listing(CLIENT, dataset):
27+
models_before = CLIENT.list_models()
28+
29+
# Creation
30+
model = CLIENT.add_model(TEST_MODEL_NAME, TEST_MODEL_REFERENCE)
31+
m_run = model.create_run(TEST_MODEL_RUN, dataset, TEST_PREDS)
32+
m_run.commit()
33+
34+
assert isinstance(model, Model)
35+
assert isinstance(m_run, ModelRun)
36+
37+
# List the models
38+
ms = CLIENT.list_models()
39+
40+
assert model in ms
41+
assert list(set(ms) - set(models_before))[0] == model
42+
43+
# Delete the model
44+
CLIENT.delete_model(model.id)
45+
ms = CLIENT.list_models()
46+
47+
assert model not in ms
48+
assert ms == models_before
49+

0 commit comments

Comments
 (0)