
Commit 6779b9a

Author: Anthony Krivonos

[Validate] Scene-level scenario tests and custom metrics uploads (#372)

* [Validate] Support uploading custom scene-level metrics
* Added scene ref ID upload and related cleanup
* Use existing constant
* Addressed PR comments
* Included level in results upload
* Added get test
* Added missing fixtures
* Fixed tests
* Made tests less flaky

1 parent 12a7d0a · commit 6779b9a

File tree: 12 files changed, +208 −15 lines

CHANGELOG.md

Lines changed: 8 additions & 0 deletions

@@ -5,6 +5,13 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.14.27](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.27) - 2022-11-04
+
+### Added
+- Support for scene-level external evaluation functions
+- Support for uploading custom scene-level metrics
+
+
 ## [0.14.26](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.26) - 2022-11-01
 
 ### Added
@@ -27,6 +34,7 @@ dataset.get_scene_from_item_ref_id(some_item['item'].reference_id)
 - `slice.type == 'object'` => list of `Annotation`/`Prediction` objects
 - `slice.type == 'scene'` => list of `Scene` objects
 
+
 ## [0.14.24](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.24) - 2022-10-19
 
 ### Fixed

nucleus/dataset_item.py

Lines changed: 1 addition & 0 deletions

@@ -158,6 +158,7 @@ def from_json(cls, payload: dict):
 
         if BACKEND_REFERENCE_ID_KEY in payload:
             payload[REFERENCE_ID_KEY] = payload[BACKEND_REFERENCE_ID_KEY]
+
         return cls(
             image_location=image_url,
             pointcloud_location=pointcloud_url,

nucleus/validate/client.py

Lines changed: 4 additions & 1 deletion

@@ -3,7 +3,7 @@
 from nucleus.connection import Connection
 from nucleus.job import AsyncJob
 
-from .constants import EVAL_FUNCTION_KEY, SCENARIO_TEST_ID_KEY
+from .constants import EVAL_FUNCTION_KEY, SCENARIO_TEST_ID_KEY, EntityLevel
 from .data_transfer_objects.eval_function import (
     CreateEvalFunction,
     EvalFunctionEntry,
@@ -205,13 +205,15 @@ def metrics(self, model_id: str):
     def create_external_eval_function(
         self,
         name: str,
+        level: EntityLevel = EntityLevel.ITEM,
     ) -> EvalFunctionEntry:
         """Creates a new external evaluation function. This external function can be used to upload evaluation
         results with functions defined and computed by the customer, without having to share the source code of the
         respective function.
 
         Args:
            name: unique name of evaluation function
+           level: level at which the eval function is run, defaults to "item"
 
        Raises:
            - NucleusAPIError if the creation of the function fails on the server side
@@ -228,6 +230,7 @@ def create_external_eval_function(
                 is_external_function=True,
                 serialized_fn=None,
                 raw_source=None,
+                level=level,
             ).dict(),
             "validate/eval_fn",
         )
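This change lets external evaluation functions be registered per item (the default) or per scene. Below is a minimal usage sketch, assuming this method is reached through the standard `NucleusClient` via `client.validate`; the API key and function name are illustrative:

from nucleus import NucleusClient
from nucleus.validate.constants import EntityLevel

client = NucleusClient("YOUR_SCALE_API_KEY")

# Register an external eval function that runs once per scene.
# Omitting `level` keeps the previous per-item behavior.
scene_eval_fn = client.validate.create_external_eval_function(
    name="custom_scene_metric",  # hypothetical name
    level=EntityLevel.SCENE,
)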

nucleus/validate/constants.py

Lines changed: 7 additions & 0 deletions

@@ -20,3 +20,10 @@ class ThresholdComparison(str, Enum):
     GREATER_THAN_EQUAL_TO = "greater_than_equal_to"
     LESS_THAN = "less_than"
     LESS_THAN_EQUAL_TO = "less_than_equal_to"
+
+
+class EntityLevel(str, Enum):
+    """Level for evaluation functions and unit tests."""
+
+    ITEM = "item"
+    SCENE = "scene"
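Since `EntityLevel` subclasses both `str` and `Enum`, its members compare equal to their raw string values, which is what allows the upload payload to send `level.value` and the backend to treat it as a plain string. A small illustrative check:

from nucleus.validate.constants import EntityLevel

# str-backed enum members compare equal to their wire-format strings
assert EntityLevel.SCENE == "scene"
assert EntityLevel.ITEM.value == "item"

# and round-trip cleanly from strings received over the API
assert EntityLevel("scene") is EntityLevel.SCENE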

nucleus/validate/data_transfer_objects/eval_function.py

Lines changed: 1 addition & 0 deletions

@@ -91,6 +91,7 @@ class CreateEvalFunction(ImmutableModel):
     is_external_function: bool
     serialized_fn: Optional[str] = None
     raw_source: Optional[str] = None
+    level: Optional[str] = None
 
     @validator("name")
     def name_is_valid(cls, v):  # pylint: disable=no-self-argument

nucleus/validate/data_transfer_objects/scenario_test_evaluations.py

Lines changed: 21 additions & 4 deletions

@@ -1,15 +1,32 @@
-from typing import List
+from typing import Optional
 
-from pydantic import validator
+from pydantic import root_validator, validator
 
 from nucleus.pydantic_base import ImmutableModel
 
 
 class EvaluationResult(ImmutableModel):
-    item_ref_id: str
-    score: float
+    item_ref_id: Optional[str] = None
+    scene_ref_id: Optional[str] = None
+    score: float = 0
     weight: float = 1
 
+    @root_validator()
+    def is_item_or_scene_provided(
+        cls, values
+    ):  # pylint: disable=no-self-argument
+        if (
+            values.get("item_ref_id") is None
+            and values.get("scene_ref_id") is None
+        ) or (
+            (
+                values.get("item_ref_id") is not None
+                and values.get("scene_ref_id") is not None
+            )
+        ):
+            raise ValueError("Must provide either item_ref_id or scene_ref_id")
+        return values
+
     @validator("score", "weight")
     def is_normalized(cls, v):  # pylint: disable=no-self-argument
         if 0 <= v <= 1:
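The new `root_validator` enforces an exclusive-or: exactly one of `item_ref_id` or `scene_ref_id` must be set. A quick sketch of construction under the new model (reference IDs are made up; in pydantic v1 the raised `ValueError` surfaces as a `ValidationError`, which subclasses `ValueError`):

from nucleus.validate.data_transfer_objects.scenario_test_evaluations import (
    EvaluationResult,
)

# Valid: a scene-level result; `score` defaults to 0 and `weight` to 1.
ok = EvaluationResult(scene_ref_id="scene_0", score=0.8)

# Invalid: neither ref ID is given (and likewise if both were given).
try:
    EvaluationResult(score=0.5)
except ValueError as err:
    print(err)  # mentions "Must provide either item_ref_id or scene_ref_id"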

nucleus/validate/scenario_test.py

Lines changed: 38 additions & 6 deletions

@@ -5,17 +5,19 @@
 and have confidence that they’re always shipping the best model.
 """
 from dataclasses import dataclass, field
-from typing import List, Optional
+from typing import List, Optional, Union
 
 from ..connection import Connection
-from ..constants import DATASET_ITEMS_KEY, NAME_KEY, SLICE_ID_KEY
+from ..constants import DATASET_ITEMS_KEY, NAME_KEY, SCENES_KEY, SLICE_ID_KEY
 from ..dataset_item import DatasetItem
+from ..scene import Scene
 from .constants import (
     EVAL_FUNCTION_ID_KEY,
     SCENARIO_TEST_ID_KEY,
     SCENARIO_TEST_METRICS_KEY,
     THRESHOLD_COMPARISON_KEY,
     THRESHOLD_KEY,
+    EntityLevel,
     ThresholdComparison,
 )
 from .data_transfer_objects.scenario_test_evaluations import EvaluationResult
@@ -162,16 +164,31 @@ def get_eval_history(self) -> List[ScenarioTestEvaluation]:
         ]
         return evaluations
 
-    def get_items(self) -> List[DatasetItem]:
+    def get_items(
+        self, level: EntityLevel = EntityLevel.ITEM
+    ) -> Union[List[DatasetItem], List[Scene]]:
+        """Gets items within a scenario test at a given level, returning a list of DatasetItem or Scene objects.
+
+        Args:
+            level: :class:`EntityLevel`
+
+        Returns:
+            A list of :class:`DatasetItem` or :class:`Scene` objects, depending on `level`.
+        """
         response = self.connection.get(
             f"validate/scenario_test/{self.id}/items",
         )
+        if level == EntityLevel.SCENE:
+            return [
+                Scene.from_json(scene, skip_validate=True)
+                for scene in response[SCENES_KEY]
+            ]
         return [
             DatasetItem.from_json(item) for item in response[DATASET_ITEMS_KEY]
         ]
 
     def set_baseline_model(self, model_id: str):
-        """Set's a new baseline model for the ScenarioTest. In order to be eligible to be a baseline,
+        """Sets a new baseline model for the ScenarioTest. In order to be eligible to be a baseline,
         this scenario test must have been evaluated using that model. The baseline model's performance
         is used as the threshold for all metrics against which other models are compared.
 
@@ -205,14 +222,28 @@ def upload_external_evaluation_results(
             len(results) > 0
         ), "Submitting evaluation requires at least one result."
 
+        level = EntityLevel.ITEM
         metric_per_ref_id = {}
         weight_per_ref_id = {}
         aggregate_weighted_sum = 0.0
         aggregate_weight = 0.0
+
         # aggregation based on https://en.wikipedia.org/wiki/Weighted_arithmetic_mean
         for r in results:
-            metric_per_ref_id[r.item_ref_id] = r.score
-            weight_per_ref_id[r.item_ref_id] = r.weight
+            # Ensure results are uploaded ONLY for items or ONLY for scenes
+            if r.scene_ref_id is not None:
+                level = EntityLevel.SCENE
+            if r.item_ref_id is not None and level == EntityLevel.SCENE:
+                raise ValueError(
+                    "All evaluation results must either pertain to a scene_ref_id or an item_ref_id, not both."
+                )
+            ref_id = (
+                r.item_ref_id if level == EntityLevel.ITEM else r.scene_ref_id
+            )
+
+            # Aggregate scores and weights
+            metric_per_ref_id[ref_id] = r.score
+            weight_per_ref_id[ref_id] = r.weight
             aggregate_weighted_sum += r.score * r.weight
             aggregate_weight += r.weight
 
@@ -224,6 +255,7 @@ def upload_external_evaluation_results(
             "overall_metric": aggregate_weighted_sum / aggregate_weight,
             "model_id": model_id,
             "slice_id": self.slice_id,
+            "level": level.value,
         }
         response = self.connection.post(
             payload,
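End to end, the uploader now infers `level` from the results' reference IDs and still reports the overall metric as a weighted arithmetic mean, sum(score_i * weight_i) / sum(weight_i). A hedged sketch, assuming an existing `scenario_test`, a registered `eval_fn`, and an already-evaluated model ID (the positional argument order is assumed, since the full signature sits outside these hunks); note the `is_normalized` validator keeps scores and weights in [0, 1]:

from nucleus.validate.constants import EntityLevel
from nucleus.validate.data_transfer_objects.scenario_test_evaluations import (
    EvaluationResult,
)

# Fetch the scenes backing the scenario test's slice.
scenes = scenario_test.get_items(level=EntityLevel.SCENE)

# One customer-computed score per scene.
results = [
    EvaluationResult(scene_ref_id=scene.reference_id, score=s, weight=w)
    for scene, s, w in zip(scenes, [0.9, 0.6], [0.25, 0.75])
]

scenario_test.upload_external_evaluation_results(eval_fn, results, "model_123")
# overall_metric = (0.9 * 0.25 + 0.6 * 0.75) / (0.25 + 0.75) = 0.675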

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@ exclude = '''
 
 [tool.poetry]
 name = "scale-nucleus"
-version = "0.14.26"
+version = "0.14.27"
 description = "The official Python client library for Nucleus, the Data Platform for AI"
 license = "MIT"
 authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]

tests/cli/conftest.py

Lines changed: 38 additions & 0 deletions

@@ -25,6 +25,17 @@ def module_scope_datasets(CLIENT):
     yield test_datasets
 
 
+@pytest.fixture(scope="module")
+def module_scope_scene_datasets(CLIENT):
+    test_scene_datasets = []
+    for i in range(3):
+        dataset_name = f"[PyTest] CLI {i} {get_uuid()} (Scene)"
+        test_scene_datasets.append(
+            CLIENT.create_dataset(dataset_name, is_scene=True)
+        )
+    yield test_scene_datasets
+
+
 @pytest.fixture(scope="function")
 def function_scope_dataset(CLIENT):
     dataset = CLIENT.create_dataset(f"[PyTest] Dataset {get_uuid()}")
@@ -49,6 +60,11 @@ def populated_dataset(module_scope_datasets):
     yield module_scope_datasets[0]
 
 
+@pytest.fixture(scope="module")
+def populated_scene_dataset(module_scope_scene_datasets):
+    yield module_scope_scene_datasets[0]
+
+
 @pytest.fixture(scope="module")
 def model(module_scope_models):
     yield module_scope_models[0]
@@ -76,6 +92,28 @@ def test_slice(populated_dataset, slice_items):
     yield slc
 
 
+@pytest.fixture(scope="module")
+def scenes(populated_dataset):
+    items = make_dataset_items()
+    populated_dataset.append(items)
+    yield items
+
+
+@pytest.fixture(scope="module")
+def slice_scenes(scenes):
+    yield scenes[:2]
+
+
+@pytest.fixture(scope="module")
+def test_scene_slice(populated_scene_dataset, slice_scenes):
+    slice_name = "[PyTest] CLI Scene Slice"
+    slc = populated_scene_dataset.create_slice(
+        name=slice_name,
+        reference_ids=[scene.reference_id for scene in slice_scenes],
+    )
+    yield slc
+
+
 @pytest.fixture(scope="module")
 def annotations(populated_dataset, slice_items):
     annotations = create_box_annotations(populated_dataset, slice_items)

tests/test_dataset.py

Lines changed: 7 additions & 0 deletions

@@ -28,6 +28,7 @@
 )
 from nucleus.errors import NucleusAPIError
 from nucleus.job import AsyncJob, JobError
+from nucleus.scene import LidarScene, VideoScene
 
 from .helpers import (
     DATASET_WITH_EMBEDDINGS,
@@ -36,9 +37,11 @@
     TEST_CATEGORY_ANNOTATIONS,
     TEST_DATASET_NAME,
     TEST_IMG_URLS,
+    TEST_LIDAR_SCENES,
     TEST_MULTICATEGORY_ANNOTATIONS,
     TEST_POLYGON_ANNOTATIONS,
     TEST_SEGMENTATION_ANNOTATIONS,
+    TEST_VIDEO_SCENES,
     assert_partial_equality,
     reference_id_from_url,
 )
@@ -94,6 +97,10 @@ def make_dataset_items():
     return ds_items_with_metadata
 
 
+def make_scenes():
+    return [VideoScene.from_json(s) for s in TEST_VIDEO_SCENES["scenes"]]
+
+
 def test_dataset_create_and_delete_no_scene(CLIENT):
     # Creation
     ds = CLIENT.create_dataset(TEST_DATASET_NAME)
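The new `make_scenes` helper feeds the scene fixtures added in this commit; its consumption pattern, mirrored in tests/validate/conftest.py below, is roughly:

# Sketch only: `dataset` is assumed to be a scene dataset created with
# CLIENT.create_dataset(..., is_scene=True).
scenes = make_scenes()
job = dataset.append(scenes, asynchronous=True)  # scene uploads run as async jobs
job.sleep_until_complete()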

tests/validate/conftest.py

Lines changed: 51 additions & 1 deletion

@@ -9,7 +9,7 @@
     create_predictions,
     get_uuid,
 )
-from tests.test_dataset import make_dataset_items
+from tests.test_dataset import make_dataset_items, make_scenes
 
 
 @pytest.fixture(scope="module")
@@ -40,6 +40,56 @@ def test_slice(validate_dataset, slice_items):
     yield slc
 
 
+@pytest.fixture(scope="module")
+def module_scope_datasets(CLIENT):
+    test_datasets = []
+    for i in range(3):
+        dataset_name = f"[PyTest] CLI {i} {get_uuid()}"
+        test_datasets.append(
+            CLIENT.create_dataset(dataset_name, is_scene=False)
+        )
+    yield test_datasets
+
+
+@pytest.fixture(scope="module")
+def module_scope_scene_datasets(CLIENT):
+    test_scene_datasets = []
+    for i in range(3):
+        dataset_name = f"[PyTest] CLI {i} {get_uuid()} (Scene)"
+        test_scene_datasets.append(
+            CLIENT.create_dataset(dataset_name, is_scene=True)
+        )
+    yield test_scene_datasets
+
+
+@pytest.fixture(scope="module")
+def populated_scene_dataset(module_scope_scene_datasets):
+    yield module_scope_scene_datasets[0]
+
+
+@pytest.fixture(scope="module")
+def slice_scenes():
+    scenes = make_scenes()[:1]
+    yield scenes
+
+
+@pytest.fixture(scope="module")
+def scenes(populated_scene_dataset, slice_scenes):
+    job = populated_scene_dataset.append(slice_scenes, asynchronous=True)
+    job.sleep_until_complete()
+    yield slice_scenes
+
+
+@pytest.fixture(scope="module")
+def test_scene_slice(populated_scene_dataset, scenes):
+    slice_name = "[PyTest] CLI Scene Slice"
+    slc = populated_scene_dataset.create_slice(
+        name=slice_name,
+        reference_ids=[scene.reference_id for scene in scenes],
+    )
+    yield slc
+
+
 @pytest.fixture(scope="module")
 def model(CLIENT):
     model_reference = "model_" + str(time.time())
