
Commit 2028e55

Author: Anthony Krivonos

[Validate] Track-level metrics upload (#375)

* [Validate] Track-level metrics upload
* Refactor upload_external_evaluation_results
* Export EntityLevel
* Version bump
* Fix test and pr fixes

1 parent c9f309a commit 2028e55

File tree

14 files changed: +202 −51 lines changed


CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -5,6 +5,11 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

+## [0.14.30](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.30) - 2022-11-29
+
+### Added
+- Support for uploading track-level metrics to external evaluation functions using track_ref_ids
+
 ## [0.14.29](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.29) - 2022-11-22

 ### Added
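
The changelog entry above is the core of this commit: EvaluationResult rows can now reference tracks, and upload_external_evaluation_results infers the track level from them. A minimal usage sketch, not taken from the commit: the scenario-test lookup, the positional arguments to upload_external_evaluation_results, and all ids are assumptions for illustration.

import nucleus
from nucleus.validate.data_transfer_objects.scenario_test_evaluations import (
    EvaluationResult,
)

client = nucleus.NucleusClient("YOUR_API_KEY")  # hypothetical API key
scenario_test = client.validate.scenario_tests[0]  # assumes at least one scenario test exists

# One result per track; exactly one of track_ref_id / item_ref_id / scene_ref_id may be set.
results = [
    EvaluationResult(track_ref_id="track_0", score=0.87, weight=1.0),
    EvaluationResult(track_ref_id="track_1", score=0.42, weight=2.0),
]

scenario_test.upload_external_evaluation_results(
    "my_external_eval_fn",  # name of a previously created external eval function (assumed)
    results,
    "model_abc123",         # hypothetical model id
)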

nucleus/__init__.py

Lines changed: 6 additions & 6 deletions
@@ -178,7 +178,7 @@ def __init__(
             import tqdm.notebook as tqdm_notebook

             self.tqdm_bar = tqdm_notebook.tqdm
-        self._connection = Connection(self.api_key, self.endpoint)
+        self.connection = Connection(self.api_key, self.endpoint)
         self.validate = Validate(self.api_key, self.endpoint)

     def __repr__(self):
@@ -1014,16 +1014,16 @@ def create_object_index(
         )

     def delete(self, route: str):
-        return self._connection.delete(route)
+        return self.connection.delete(route)

     def get(self, route: str):
-        return self._connection.get(route)
+        return self.connection.get(route)

     def post(self, payload: dict, route: str):
-        return self._connection.post(payload, route)
+        return self.connection.post(payload, route)

     def put(self, payload: dict, route: str):
-        return self._connection.put(payload, route)
+        return self.connection.put(payload, route)

     # TODO: Fix return type, can be a list as well. Brings on a lot of mypy errors ...
     def make_request(
@@ -1054,7 +1054,7 @@ def make_request(
                 "Received defined payload with GET request! Will ignore payload"
             )
             payload = None
-        return self._connection.make_request(payload, route, requests_command, return_raw_response)  # type: ignore
+        return self.connection.make_request(payload, route, requests_command, return_raw_response)  # type: ignore

     def _set_api_key(self, api_key):
         """Fetch API key from environment variable NUCLEUS_API_KEY if not set"""

nucleus/dataset.py

Lines changed: 1 addition & 1 deletion
@@ -1863,7 +1863,7 @@ def tracks(self) -> List[Track]:
         tracks_list = [
             Track.from_json(
                 payload=track,
-                client=self._client,
+                connection=self._client.connection,
             )
             for track in response[TRACKS_KEY]
         ]

nucleus/scene.py

Lines changed: 8 additions & 2 deletions
@@ -330,7 +330,10 @@ def from_json(
         frames = [Frame.from_json(frame) for frame in frames_payload]
         tracks_payload = payload.get(TRACKS_KEY, [])
         tracks = (
-            [Track.from_json(track, client) for track in tracks_payload]
+            [
+                Track.from_json(track, connection=client.connection)
+                for track in tracks_payload
+            ]
             if client
             else []
         )
@@ -680,7 +683,10 @@ def from_json(
         items = [DatasetItem.from_json(item) for item in items_payload]
         tracks_payload = payload.get(TRACKS_KEY, [])
         tracks = (
-            [Track.from_json(track, client) for track in tracks_payload]
+            [
+                Track.from_json(track, connection=client.connection)
+                for track in tracks_payload
+            ]
             if client
             else []
         )

nucleus/track.py

Lines changed: 5 additions & 5 deletions
@@ -12,7 +12,7 @@
 )

 if TYPE_CHECKING:
-    from . import NucleusClient
+    from . import Connection


 @dataclass # pylint: disable=R0902
@@ -25,7 +25,7 @@ class Track: # pylint: disable=R0902
         metadata: Arbitrary key/value dictionary of info to attach to this track.
     """

-    _client: "NucleusClient"
+    _connection: "Connection"
     dataset_id: str
     reference_id: str
     metadata: Optional[dict] = None
@@ -41,10 +41,10 @@ def __eq__(self, other):
         )

     @classmethod
-    def from_json(cls, payload: dict, client: "NucleusClient"):
+    def from_json(cls, payload: dict, connection: "Connection"):
         """Instantiates track object from schematized JSON dict payload."""
         return cls(
-            _client=client,
+            _connection=connection,
             reference_id=str(payload[REFERENCE_ID_KEY]),
             dataset_id=str(payload[DATASET_ID_KEY]),
             metadata=payload.get(METADATA_KEY, None),
@@ -79,7 +79,7 @@ def update(
             entire metadata object will be overwritten. Otherwise, only the keys in metadata will be overwritten.
         """

-        self._client.make_request(
+        self._connection.make_request(
             payload={
                 REFERENCE_ID_KEY: self.reference_id,
                 METADATA_KEY: metadata,
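
With this signature change, a Track is constructed from a payload plus a bare Connection rather than a full NucleusClient. A hedged sketch of the call site; the literal payload key names are inferred from the constants referenced above and are not confirmed by this diff.

import nucleus
from nucleus.track import Track

client = nucleus.NucleusClient("YOUR_API_KEY")  # hypothetical API key

# Hypothetical payload shaped like a track returned by the API; the real key
# strings come from REFERENCE_ID_KEY / DATASET_ID_KEY / METADATA_KEY.
track_payload = {
    "reference_id": "track_0",
    "dataset_id": "ds_sample_id",
    "metadata": {"label": "pedestrian"},
}

track = Track.from_json(payload=track_payload, connection=client.connection)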

nucleus/validate/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 ]

 from .client import Validate
-from .constants import ThresholdComparison
+from .constants import EntityLevel, ThresholdComparison
 from .data_transfer_objects.eval_function import (
     EvalFunctionEntry,
     EvaluationCriterion,

nucleus/validate/client.py

Lines changed: 1 addition & 1 deletion
@@ -213,7 +213,7 @@ def create_external_eval_function(

         Args:
             name: unique name of evaluation function
-            level: level at which the eval function is run, defaults to "item"
+            level: level at which the eval function is run, defaults to EntityLevel.ITEM.

         Raises:
            - NucleusAPIError if the creation of the function fails on the server side

nucleus/validate/constants.py

Lines changed: 8 additions & 1 deletion
@@ -23,7 +23,14 @@ class ThresholdComparison(str, Enum):


 class EntityLevel(str, Enum):
-    """Level for evaluation functions and unit tests."""
+    """
+    Data level at which evaluation functions produce outputs.
+    For instance, when comparing results across dataset items, use
+    `EntityLevel.ITEM`. For scenes, use `EntityLevel.SCENE`. Finally,
+    when comparing results between tracks within a single scene or a
+    holistic item datset, use `EntityLevel.TRACK`.
+    """

+    TRACK = "track"
     ITEM = "item"
     SCENE = "scene"

nucleus/validate/data_transfer_objects/scenario_test_evaluations.py

Lines changed: 9 additions & 9 deletions
@@ -6,6 +6,7 @@


 class EvaluationResult(ImmutableModel):
+    track_ref_id: Optional[str] = None
     item_ref_id: Optional[str] = None
     scene_ref_id: Optional[str] = None
     score: float = 0
@@ -15,16 +16,15 @@ class EvaluationResult(ImmutableModel):
     def is_item_or_scene_provided(
         cls, values
     ): # pylint: disable=no-self-argument
-        if (
-            values.get("item_ref_id") is None
-            and values.get("scene_ref_id") is None
-        ) or (
-            (
-                values.get("item_ref_id") is not None
-                and values.get("scene_ref_id") is not None
+        ref_ids = [
+            values.get("track_ref_id", None),
+            values.get("item_ref_id", None),
+            values.get("scene_ref_id", None),
+        ]
+        if len([ref_id for ref_id in ref_ids if ref_id is not None]) != 1:
+            raise ValueError(
+                "Must provide exactly one of track_ref_id, item_ref_id, or scene_ref_id"
             )
-        ):
-            raise ValueError("Must provide either item_ref_id or scene_ref_id")
         return values

     @validator("score", "weight")

nucleus/validate/scenario_test.py

Lines changed: 48 additions & 15 deletions
@@ -8,9 +8,16 @@
 from typing import List, Optional, Union

 from ..connection import Connection
-from ..constants import DATASET_ITEMS_KEY, NAME_KEY, SCENES_KEY, SLICE_ID_KEY
+from ..constants import (
+    DATASET_ITEMS_KEY,
+    NAME_KEY,
+    SCENES_KEY,
+    SLICE_ID_KEY,
+    TRACKS_KEY,
+)
 from ..dataset_item import DatasetItem
 from ..scene import Scene
+from ..track import Track
 from .constants import (
     EVAL_FUNCTION_ID_KEY,
     SCENARIO_TEST_ID_KEY,
@@ -166,8 +173,8 @@ def get_eval_history(self) -> List[ScenarioTestEvaluation]:

     def get_items(
         self, level: EntityLevel = EntityLevel.ITEM
-    ) -> Union[List[DatasetItem], List[Scene]]:
-        """Gets items within a scenario test at a given level, returning a list of DatasetItem or Scene objects.
+    ) -> Union[List[Track], List[DatasetItem], List[Scene]]:
+        """Gets items within a scenario test at a given level, returning a list of Track, DatasetItem, or Scene objects.

         Args:
             level: :class:`EntityLevel`
@@ -178,14 +185,22 @@ def get_items(
         response = self.connection.get(
             f"validate/scenario_test/{self.id}/items",
         )
+        if level == EntityLevel.TRACK:
+            return [
+                Track.from_json(track, connection=self.connection)
+                for track in response.get(TRACKS_KEY, [])
+            ]
         if level == EntityLevel.SCENE:
             return [
                 Scene.from_json(scene, skip_validate=True)
-                for scene in response[SCENES_KEY]
+                for scene in response.get(SCENES_KEY, [])
             ]
-        return [
-            DatasetItem.from_json(item) for item in response[DATASET_ITEMS_KEY]
-        ]
+        if level == EntityLevel.ITEM:
+            return [
+                DatasetItem.from_json(item)
+                for item in response.get(DATASET_ITEMS_KEY, [])
+            ]
+        raise ValueError(f"Invalid entity level: {level}")

     def set_baseline_model(self, model_id: str):
         """Sets a new baseline model for the ScenarioTest. In order to be eligible to be a baseline,
@@ -222,23 +237,41 @@ def upload_external_evaluation_results(
             len(results) > 0
         ), "Submitting evaluation requires at least one result."

-        level = EntityLevel.ITEM
+        level: Optional[EntityLevel] = None
         metric_per_ref_id = {}
         weight_per_ref_id = {}
         aggregate_weighted_sum = 0.0
         aggregate_weight = 0.0

+        # Ensures reults at only one EntityLevel are provided, otherwise throwing a ValueError
+        def ensure_level_consistency_or_raise(
+            cur_level: Optional[EntityLevel], new_level: EntityLevel
+        ):
+            if level is not None and level != new_level:
+                raise ValueError(
+                    f"All evaluation results must only pertain to one level. Received {cur_level} then {new_level}"
+                )
+
         # aggregation based on https://en.wikipedia.org/wiki/Weighted_arithmetic_mean
         for r in results:
-            # Ensure results are uploaded ONLY for items or ONLY for scenes
+            # Ensure results are uploaded ONLY for ONE OF tracks, items, and scenes
+            if r.track_ref_id is not None:
+                ensure_level_consistency_or_raise(level, EntityLevel.TRACK)
+                level = EntityLevel.TRACK
+            if r.item_ref_id is not None:
+                ensure_level_consistency_or_raise(level, EntityLevel.ITEM)
+                level = EntityLevel.ITEM
             if r.scene_ref_id is not None:
+                ensure_level_consistency_or_raise(level, EntityLevel.SCENE)
                 level = EntityLevel.SCENE
-            if r.item_ref_id is not None and level == EntityLevel.SCENE:
-                raise ValueError(
-                    "All evaluation results must either pertain to a scene_ref_id or an item_ref_id, not both."
-                )
             ref_id = (
-                r.item_ref_id if level == EntityLevel.ITEM else r.scene_ref_id
+                r.track_ref_id
+                if level == EntityLevel.TRACK
+                else (
+                    r.item_ref_id
+                    if level == EntityLevel.ITEM
+                    else r.scene_ref_id
+                )
             )

             # Aggregate scores and weights
@@ -255,7 +288,7 @@ def upload_external_evaluation_results(
             "overall_metric": aggregate_weighted_sum / aggregate_weight,
             "model_id": model_id,
             "slice_id": self.slice_id,
-            "level": level.value,
+            "level": level.value if level else None,
         }
         response = self.connection.post(
             payload,
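
The upload path above infers a single EntityLevel from the results and reports a weighted arithmetic mean as the overall metric. A standalone sketch of that aggregation (not the library code itself), following the aggregate_weighted_sum / aggregate_weight bookkeeping in the diff:

from typing import List, Tuple


def weighted_overall_metric(results: List[Tuple[float, float]]) -> float:
    """Weighted arithmetic mean over (score, weight) pairs."""
    aggregate_weighted_sum = 0.0
    aggregate_weight = 0.0
    for score, weight in results:
        aggregate_weighted_sum += score * weight
        aggregate_weight += weight
    return aggregate_weighted_sum / aggregate_weight


# Two track-level results with unequal weights: (0.8*1.0 + 0.4*3.0) / 4.0 == 0.5
print(weighted_overall_metric([(0.8, 1.0), (0.4, 3.0)]))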

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ exclude = '''

 [tool.poetry]
 name = "scale-nucleus"
-version = "0.14.29"
+version = "0.14.30"
 description = "The official Python client library for Nucleus, the Data Platform for AI"
 license = "MIT"
 authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]

tests/helpers.py

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@
 DATASET_WITH_EMBEDDINGS = "ds_c8jwdhy4y4f0078hzceg"
 NUCLEUS_PYTEST_USER_ID = "60ad648c85db770026e9bf77"

+EVAL_FUNCTION_NAME = "eval_fn"
 EVAL_FUNCTION_THRESHOLD = 0.5
 EVAL_FUNCTION_COMPARISON = ThresholdComparison.GREATER_THAN_EQUAL_TO

tests/test_track.py

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
+import time
 from copy import deepcopy

 import pytest
@@ -69,7 +70,7 @@ def test_create_mp_with_tracks(CLIENT, dataset_scene):
     expected_track_reference_ids = [
         ann["track_reference_id"] for ann in TEST_SCENE_BOX_PREDS_WITH_TRACK
     ]
-    model_reference = "model_test_create_mp_with_tracks"
+    model_reference = "model_" + str(time.time())
     model = CLIENT.create_model(TEST_MODEL_NAME, model_reference)

     # Act
