
Commit 64c7bd8

gatli and phil-scale authored
Speed up ModelCI tests and add annotations as required by the backend (#181)
* Speed up Model CI tests and add an evaluation test with missing predictions
* Fix create_annotations bug
* Add TODO to remove test
* Add annotation dependency to unit_test
* Update message

Co-authored-by: phil-scale <phil.chen@scale.com>
1 parent c5162e3 commit 64c7bd8

6 files changed: +181 additions, -52 deletions


nucleus/__init__.py

Lines changed: 13 additions & 4 deletions
@@ -35,7 +35,6 @@
 ]
 
 import os
-import time
 import warnings
 from typing import Dict, List, Optional, Sequence, Union
 
@@ -102,6 +101,7 @@
     DatasetItemRetrievalError,
     ModelCreationError,
     ModelRunCreationError,
+    NoAPIKey,
     NotFoundError,
     NucleusAPIError,
 )
@@ -150,11 +150,11 @@ class NucleusClient:
 
     def __init__(
         self,
-        api_key: str,
+        api_key: Optional[str] = None,
         use_notebook: bool = False,
         endpoint: str = None,
     ):
-        self.api_key = api_key
+        self.api_key = self._set_api_key(api_key)
         self.tqdm_bar = tqdm.tqdm
         if endpoint is None:
             self.endpoint = os.environ.get(
@@ -166,7 +166,6 @@ def __init__(
         if use_notebook:
             self.tqdm_bar = tqdm_notebook.tqdm
         self._connection = Connection(self.api_key, self.endpoint)
-
         self.modelci = ModelCI(self.api_key, self.endpoint)
 
     def __repr__(self):
@@ -936,3 +935,13 @@ def handle_bad_response(
         self._connection.handle_bad_response(
             endpoint, requests_command, requests_response, aiohttp_response
         )
+
+    def _set_api_key(self, api_key):
+        """Fetch API key from environment variable NUCLEUS_API_KEY if not set"""
+        api_key = (
+            api_key if api_key else os.environ.get("NUCLEUS_API_KEY", None)
+        )
+        if api_key is None:
+            raise NoAPIKey()
+
+        return api_key
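
With this change the api_key argument is optional: _set_api_key falls back to the NUCLEUS_API_KEY environment variable and raises NoAPIKey when neither is provided. A minimal sketch of client construction after this commit (the key string is a placeholder, not a real key):

import os

from nucleus import NucleusClient

# Assumption for illustration: a placeholder key is exported in the environment,
# e.g. `export NUCLEUS_API_KEY=...` in the shell before starting Python.
os.environ["NUCLEUS_API_KEY"] = "placeholder-api-key"

client = NucleusClient()  # key resolved from NUCLEUS_API_KEY via _set_api_key()
explicit_client = NucleusClient(api_key="placeholder-api-key")  # still supported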

nucleus/errors.py

Lines changed: 9 additions & 0 deletions
@@ -62,3 +62,12 @@ def __init__(
         message += "\n This likely indicates temporary downtime of the API, please try again in a minute or two"
 
         super().__init__(message)
+
+
+class NoAPIKey(Exception):
+    def __init__(
+        self,
+        message="You need to pass an API key to the NucleusClient or set the environment variable NUCLEUS_API_KEY",
+    ):
+        self.message = message
+        super().__init__(self.message)
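
The new NoAPIKey error makes the failure mode explicit when no key can be found. A small sketch of how a caller might surface it, assuming the class is imported from nucleus.errors as added above:

from nucleus import NucleusClient
from nucleus.errors import NoAPIKey

try:
    client = NucleusClient()  # no api_key argument and NUCLEUS_API_KEY unset
except NoAPIKey as err:
    print(err.message)  # tells the user to pass a key or set NUCLEUS_API_KEY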

tests/modelci/conftest.py

Lines changed: 57 additions & 25 deletions
@@ -2,11 +2,48 @@
 
 import pytest
 
-from tests.helpers import TEST_MODEL_NAME, TEST_SLICE_NAME, get_uuid
+from nucleus import BoxAnnotation
+from tests.helpers import (
+    TEST_BOX_ANNOTATIONS,
+    TEST_MODEL_NAME,
+    TEST_SLICE_NAME,
+    get_uuid,
+)
+from tests.modelci.helpers import create_box_annotations, create_predictions
 from tests.test_dataset import make_dataset_items
 
 
-@pytest.fixture()
+@pytest.fixture(scope="module")
+def modelci_dataset(CLIENT):
+    """SHOULD NOT BE MUTATED IN TESTS. This dataset lives for the whole test module scope."""
+    ds = CLIENT.create_dataset("[Test Model CI] Dataset", is_scene=False)
+    yield ds
+
+    CLIENT.delete_dataset(ds.id)
+
+
+@pytest.fixture(scope="module")
+def dataset_items(modelci_dataset):
+    items = make_dataset_items()
+    modelci_dataset.append(items)
+    yield items
+
+
+@pytest.fixture(scope="module")
+def slice_items(dataset_items):
+    yield dataset_items[:2]
+
+
+@pytest.fixture(scope="module")
+def test_slice(modelci_dataset, slice_items):
+    slc = modelci_dataset.create_slice(
+        name=TEST_SLICE_NAME,
+        reference_ids=[item.reference_id for item in slice_items],
+    )
+    yield slc
+
+
+@pytest.fixture(scope="module")
 def model(CLIENT):
     model_reference = "model_" + str(time.time())
     model = CLIENT.create_model(TEST_MODEL_NAME, model_reference)
@@ -15,34 +52,29 @@ def model(CLIENT):
     CLIENT.delete_model(model.id)
 
 
-@pytest.fixture()
-def unit_test(CLIENT, dataset):
-    items = make_dataset_items()
-    dataset.append(items)
+@pytest.fixture(scope="module")
+def annotations(modelci_dataset, slice_items):
+    annotations = create_box_annotations(modelci_dataset, slice_items)
+    yield annotations
+
+
+@pytest.fixture(scope="module")
+def predictions(model, modelci_dataset, annotations):
+    predictions = create_predictions(modelci_dataset, model, annotations)
+    yield predictions
+
+
+@pytest.fixture(scope="module")
+@pytest.mark.usefixtures(
+    "annotations"
+)  # Unit test needs to have annotations in the slice
+def unit_test(CLIENT, test_slice):
     test_name = "unit_test_" + get_uuid()  # use uuid to make unique
-    slc = dataset.create_slice(
-        name=TEST_SLICE_NAME,
-        reference_ids=[items[0].reference_id],
-    )
     unit_test = CLIENT.modelci.create_unit_test(
         name=test_name,
-        slice_id=slc.id,
+        slice_id=test_slice.id,
         evaluation_criteria=[CLIENT.modelci.eval_functions.bbox_recall > 0.5],
     )
     yield unit_test
 
     CLIENT.modelci.delete_unit_test(unit_test.id)
-
-
-@pytest.fixture()
-def test_slice(CLIENT, dataset):
-    items = make_dataset_items()
-    dataset.append(items)
-    slice_name = TEST_SLICE_NAME + f"_{get_uuid()}"
-    slc = dataset.create_slice(
-        name=slice_name,
-        reference_ids=[items[0].reference_id],
-    )
-    yield slc
-
-    CLIENT.delete_slice(slc.id)
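
These module-scoped fixtures are where most of the speed-up comes from: the dataset, slice, annotations, model, and predictions are created once per test module instead of once per test, and the dataset, model, and unit test clean up after themselves in teardown. A hypothetical test using the shared fixtures (illustration only, not part of this commit) could look like:

import pytest

@pytest.mark.integration
def test_example_with_shared_fixtures(CLIENT, unit_test, predictions):
    # unit_test is built on the module-scoped slice that already has annotations;
    # predictions were uploaded once for the module-scoped model.
    assert unit_test.id is not None
    assert len(predictions) > 0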

tests/modelci/helpers.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+from typing import List
+
+from nucleus import BoxAnnotation, BoxPrediction, Dataset, DatasetItem, Model
+
+
+def create_box_annotations(
+    dataset: Dataset, dataset_items: List[DatasetItem]
+) -> List[BoxAnnotation]:
+    annotations = [
+        BoxAnnotation(
+            label=f"[Pytest] Box Annotation {ds_item.reference_id}",
+            x=50 + i * 10,
+            y=60 + i * 10,
+            width=70 + i * 10,
+            height=80 + i * 10,
+            reference_id=ds_item.reference_id,
+            annotation_id=f"[Pytest] Box Annotation Annotation Id{i}",
+        )
+        for i, ds_item in enumerate(dataset_items)
+    ]
+    dataset.annotate(annotations)
+    return annotations
+
+
+def create_predictions(
+    dataset: Dataset, model: Model, annotations: List[BoxAnnotation]
+) -> List[BoxPrediction]:
+    predictions = [
+        BoxPrediction(
+            label=ann.label,
+            x=ann.x,
+            y=ann.y,
+            width=ann.width,
+            height=ann.height,
+            reference_id=ann.reference_id,
+            confidence=0.1 * i,
+        )
+        for i, ann in enumerate(annotations)
+    ]
+    dataset.upload_predictions(model, predictions)
+    return predictions
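
These helpers create one deterministic box annotation per dataset item and mirror each annotation as a prediction with a distinct confidence, so evaluation results are predictable. A hedged usage sketch, assuming dataset, model, and dataset_items (a list of DatasetItem objects already appended to the dataset) exist:

from tests.modelci.helpers import create_box_annotations, create_predictions

annotations = create_box_annotations(dataset, dataset_items)  # uploads via dataset.annotate()
predictions = create_predictions(dataset, model, annotations)  # uploads via dataset.upload_predictions()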

tests/modelci/test_unit_test.py

Lines changed: 3 additions & 12 deletions
@@ -5,10 +5,8 @@
 from tests.helpers import (
     EVAL_FUNCTION_COMPARISON,
     EVAL_FUNCTION_THRESHOLD,
-    TEST_SLICE_NAME,
     get_uuid,
 )
-from tests.test_dataset import make_dataset_items
 
 
 def test_unit_test_metric_creation(CLIENT, unit_test):
@@ -42,23 +40,16 @@ def test_list_unit_test(CLIENT, test_slice):
     CLIENT.modelci.delete_unit_test(unit_test.id)
 
 
-def test_unit_test_items(CLIENT, dataset):
-    # create some dataset_items for the unit test to reference
-    items = make_dataset_items()
-    dataset.append(items)
+def test_unit_test_items(CLIENT, test_slice, slice_items, annotations):
     test_name = "unit_test_" + get_uuid()  # use uuid to make unique
-    slc = dataset.create_slice(
-        name=TEST_SLICE_NAME,
-        reference_ids=[item.reference_id for item in items],
-    )
 
     unit_test = CLIENT.modelci.create_unit_test(
         name=test_name,
-        slice_id=slc.id,
+        slice_id=test_slice.id,
         evaluation_criteria=[CLIENT.modelci.eval_functions.bbox_iou() > 0.5],
     )
 
-    expected_items_locations = [item.image_location for item in items]
+    expected_items_locations = [item.image_location for item in slice_items]
     actual_items_locations = [
         item.image_location for item in unit_test.get_items()
     ]

tests/modelci/test_unit_test_evaluation.py

Lines changed: 58 additions & 11 deletions
@@ -1,26 +1,73 @@
 import pytest
 
-from nucleus import BoxAnnotation, BoxPrediction
 from nucleus.job import AsyncJob
 from nucleus.modelci.unit_test_evaluation import (
     UnitTestEvaluation,
     UnitTestItemEvaluation,
 )
-from tests.helpers import (
-    EVAL_FUNCTION_THRESHOLD,
-    TEST_BOX_ANNOTATIONS,
-    TEST_BOX_PREDICTIONS,
-)
+from tests.helpers import EVAL_FUNCTION_THRESHOLD, get_uuid
+from tests.modelci.helpers import create_predictions
 
 
 @pytest.mark.integration
-def test_unit_test_evaluation(CLIENT, dataset, model, unit_test):
-    annotations = [BoxAnnotation(**TEST_BOX_ANNOTATIONS[0])]
-    dataset.annotate(annotations=annotations)
-    predictions = [BoxPrediction(**TEST_BOX_PREDICTIONS[0])]
-    dataset.upload_predictions(model, predictions)
+def test_unit_test_evaluation(
+    CLIENT, modelci_dataset, model, unit_test, annotations, predictions
+):
+    iou = CLIENT.modelci.eval_functions.bbox_iou
+    # NOTE: Another criterion is defined in the unit_test fixture
+    unit_test.add_criterion(iou() > EVAL_FUNCTION_THRESHOLD)
+
+    job: AsyncJob = CLIENT.modelci.evaluate_model_on_unit_tests(
+        model.id, [unit_test.name]
+    )
+    job.sleep_until_complete()
 
+    criteria = unit_test.get_criteria()
+    evaluations = unit_test.get_eval_history()
+    assert isinstance(evaluations, list)
+    assert len(evaluations) == len(criteria)
+    assert all(
+        isinstance(evaluation, UnitTestEvaluation)
+        for evaluation in evaluations
+    )
+    assert all(
+        evaluation.unit_test_id == unit_test.id for evaluation in evaluations
+    )
+    assert all(evaluation.model_id == model.id for evaluation in evaluations)
+
+    unit_test_slice = CLIENT.get_slice(unit_test.slice_id)
+    item_evaluations = evaluations[0].item_evals
+    assert isinstance(item_evaluations, list)
+    assert len(item_evaluations) == len(
+        unit_test_slice.items_and_annotations()
+    )
+    assert isinstance(item_evaluations[0], UnitTestItemEvaluation)
+    assert all(
+        eval.evaluation_id == evaluations[0].id for eval in item_evaluations
+    )
+    assert all(eval.unit_test_id == unit_test.id for eval in item_evaluations)
+
+
+@pytest.mark.integration
+@pytest.mark.xfail(
+    reason="Missing predictions is currently treated as failure in evaluation."
+)
+@pytest.mark.skip
+def test_unit_test_evaluation_no_prediction_for_last_item(
+    # TODO(gunnar): Remove this slow integration test after this is confirmed and tested on the evaluation side.
+    # There's no reason to do unit testing for evaluation here.
+    CLIENT,
+    modelci_dataset,
+    unit_test,
+    annotations,
+):
+    uuid = get_uuid()
+    model = CLIENT.create_model(
+        f"[Model CI Test] {uuid}", reference_id=f"model_ci_{uuid}"
+    )
+    create_predictions(modelci_dataset, model, annotations[:-1])
     iou = CLIENT.modelci.eval_functions.bbox_iou
+    # NOTE: Another criterion is defined in the unit_test fixture
     unit_test.add_criterion(iou() > EVAL_FUNCTION_THRESHOLD)
 
     job: AsyncJob = CLIENT.modelci.evaluate_model_on_unit_tests(
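
Taken together, the updated tests exercise the Model CI evaluation flow end to end. A condensed sketch of that flow, using only calls that appear in this diff (my_slice and my_model stand in for objects created elsewhere):

unit_test = CLIENT.modelci.create_unit_test(
    name="my_unit_test",
    slice_id=my_slice.id,
    evaluation_criteria=[CLIENT.modelci.eval_functions.bbox_recall > 0.5],
)
unit_test.add_criterion(CLIENT.modelci.eval_functions.bbox_iou() > 0.5)

job = CLIENT.modelci.evaluate_model_on_unit_tests(my_model.id, [unit_test.name])
job.sleep_until_complete()

evaluations = unit_test.get_eval_history()  # the tests above assert one evaluation per criterion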
