
Commit c641abb

ntamas92 and gatli authored

Implement dataset.query_objects method (#402)

* Implement dataset.query_objects method
* Remove true_negative from enum
* Fix confusion category true positive case
* Rename IOUMatch to EvaluationMatch
* Rename documentation
* Add documentation to EvaluationMatch
* Propagate model_run_id parameter
* Bump sdk version

Co-authored-by: Gunnar Atli Thoroddsen <gunnar.thoroddsen@scale.com>

1 parent 27c7dfd · commit c641abb

File tree: 6 files changed, +178 −5 lines

CHANGELOG.md

Lines changed: 13 additions & 0 deletions

```diff
@@ -6,6 +6,19 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 
+## [0.16.4](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.16.4) - 2023-10-23
+
+### Added
+- Added a `query_objects` method on the Dataset class.
+  - Example
+  ```shell
+  >>> ds = client.get_dataset('ds_id')
+  >>> objects = ds.query_objects('annotations.metadata.distance_to_device > 150', ObjectQueryType.GROUND_TRUTH_ONLY)
+  [CuboidAnnotation(label="", dimensions={}, ...), ...]
+  ```
+- Added `EvaluationMatch` class to represent IOU Matches, False Positives and False Negatives retrieved through the `query_objects` method
+
+
 ## [0.16.3](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.16.3) - 2023-10-10
 
 ### Added
```
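For reference, a minimal end-to-end sketch of the new method as described in this changelog entry (not part of the diff; the API key, dataset ID, and metadata field below are hypothetical placeholders):

```python
from nucleus import NucleusClient
from nucleus.dataset import ObjectQueryType

client = NucleusClient("YOUR_SCALE_API_KEY")  # hypothetical API key
ds = client.get_dataset("ds_id")              # hypothetical dataset ID

# Stream ground-truth annotations whose metadata matches a structured query.
# "annotations.metadata.distance_to_device" is a hypothetical metadata field.
for annotation in ds.query_objects(
    "annotations.metadata.distance_to_device > 150",
    ObjectQueryType.GROUND_TRUTH_ONLY,
):
    print(annotation.label)
```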

nucleus/async_job.py

Lines changed: 1 addition & 1 deletion

The removed and added lines render identically; the change is a whitespace-only indentation fix in the docstring:

```diff
@@ -177,7 +177,7 @@ def result_urls(self, wait_for_completion=True) -> List[str]:
 
         Parameters:
             wait_for_completion: Defines whether the call shall wait for
-               the job to complete. Defaults to True
+                the job to complete. Defaults to True
 
         Returns:
             A list of signed Scale URLs which contain batches of embeddings.
```

nucleus/constants.py

Lines changed: 6 additions & 0 deletions

```diff
@@ -59,6 +59,8 @@
 FX_KEY = "fx"
 FY_KEY = "fy"
 GEOMETRY_KEY = "geometry"
+GROUND_TRUTH_ANNOTATION_ID_KEY = "ground_truth_annotation_id"
+GROUND_TRUTH_ANNOTATION_LABEL_KEY = "ground_truth_annotation_label"
 HEADING_KEY = "heading"
 HEIGHT_KEY = "height"
 ID_KEY = "id"
@@ -68,6 +70,7 @@
 IMAGE_URL_KEY = "image_url"
 INDEX_KEY = "index"
 INDEX_CONTINUOUS_ENABLE_KEY = "enable"
+IOU_KEY = "iou"
 ITEMS_KEY = "items"
 ITEM_KEY = "item"
 ITEM_METADATA_SCHEMA_KEY = "item_metadata_schema"
@@ -97,6 +100,8 @@
 MODEL_TAGS_KEY = "tags"
 MODEL_ID_KEY = "model_id"
 MODEL_RUN_ID_KEY = "model_run_id"
+MODEL_PREDICTION_ID_KEY = "model_prediction_id"
+MODEL_PREDICTION_LABEL_KEY = "model_prediction_label"
 NAME_KEY = "name"
 NEW_ITEMS = "new_items"
 NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"
@@ -135,6 +140,7 @@
 TRACK_REFERENCE_ID_KEY = "track_reference_id"
 TRACK_REFERENCE_IDS_KEY = "track_reference_ids"
 TRACKS_KEY = "tracks"
+TRUE_POSITIVE_KEY = "true_positive"
 TYPE_KEY = "type"
 UPDATED_ITEMS = "updated_items"
 UPDATE_KEY = "update"
```

nucleus/dataset.py

Lines changed: 55 additions & 3 deletions

```diff
@@ -1,5 +1,6 @@
 import datetime
 import os
+from enum import Enum
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -16,7 +17,8 @@
 
 from nucleus.annotation_uploader import AnnotationUploader, PredictionUploader
 from nucleus.async_job import AsyncJob, EmbeddingsExportJob
-from nucleus.prediction import Prediction, from_json
+from nucleus.evaluation_match import EvaluationMatch
+from nucleus.prediction import from_json as prediction_from_json
 from nucleus.track import Track
 from nucleus.url_utils import sanitize_string_args
 from nucleus.utils import (
@@ -77,6 +79,7 @@
     construct_model_run_creation_payload,
     construct_taxonomy_payload,
 )
+from .prediction import Prediction
 from .scene import LidarScene, Scene, VideoScene, check_all_scene_paths_remote
 from .slice import (
     Slice,
@@ -98,6 +101,14 @@
 WARN_FOR_LARGE_SCENES_UPLOAD = 5
 
 
+class ObjectQueryType(str, Enum):
+    IOU = "iou"
+    FALSE_POSITIVE = "false_positive"
+    FALSE_NEGATIVE = "false_negative"
+    PREDICTIONS_ONLY = "predictions_only"
+    GROUND_TRUTH_ONLY = "ground_truth_only"
+
+
 class Dataset:
     """Datasets are collections of your data that can be associated with models.
 
@@ -1681,7 +1692,7 @@ def upload_predictions(
         :class:`Category<CategoryPrediction>`, and :class:`Category<SceneCategoryPrediction>` predictions. Cuboid predictions
         can only be uploaded to a :class:`pointcloud DatasetItem<LidarScene>`.
 
-        When uploading an prediction, you need to specify which item you are
+        When uploading a prediction, you need to specify which item you are
         annotating via the reference_id you provided when uploading the image
         or pointcloud.
 
@@ -1854,7 +1865,7 @@ def prediction_loc(self, model, reference_id, annotation_id):
             :class:`KeypointsPrediction` \
         ]: Model prediction object with the specified annotation ID.
         """
-        return from_json(
+        return prediction_from_json(
             self._client.make_request(
                 payload=None,
                 route=f"dataset/{self.id}/model/{model.id}/loc/{reference_id}/{annotation_id}",
@@ -1999,6 +2010,47 @@ def query_scenes(self, query: str) -> Iterable[Scene]:
         for item_json in json_generator:
             yield Scene.from_json(item_json, None, True)
 
+    def query_objects(
+        self,
+        query: str,
+        query_type: ObjectQueryType,
+        model_run_id: Optional[str] = None,
+    ) -> Iterable[Union[Annotation, Prediction, EvaluationMatch]]:
+        """
+        Fetches all objects in the dataset that pertain to a given structured query.
+        The results are either Predictions, Annotations, or Evaluation Matches, based on the query_type input parameter.
+
+        Args:
+            query: Structured query compatible with the `Nucleus query language <https://nucleus.scale.com/docs/query-language-reference>`_.
+            query_type: Defines the type of the object to query.
+
+        Returns:
+            An iterable of either Predictions, Annotations, or Evaluation Matches.
+        """
+        json_generator = paginate_generator(
+            client=self._client,
+            endpoint=f"dataset/{self.id}/queryObjectsPage",
+            result_key=ITEMS_KEY,
+            page_size=MAX_ES_PAGE_SIZE,
+            query=query,
+            patch_mode=query_type,
+            model_run_id=model_run_id,
+        )
+
+        for item_json in json_generator:
+            if query_type == ObjectQueryType.GROUND_TRUTH_ONLY:
+                yield Annotation.from_json(item_json)
+            elif query_type == ObjectQueryType.PREDICTIONS_ONLY:
+                yield prediction_from_json(item_json)
+            elif query_type in [
+                ObjectQueryType.IOU,
+                ObjectQueryType.FALSE_POSITIVE,
+                ObjectQueryType.FALSE_NEGATIVE,
+            ]:
+                yield EvaluationMatch.from_json(item_json)
+            else:
+                raise ValueError("Unknown object type", query_type)
+
     @property
     def tracks(self) -> List[Track]:
         """Tracks unique to this dataset.
```

nucleus/evaluation_match.py

Lines changed: 102 additions & 0 deletions

```diff
@@ -0,0 +1,102 @@
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional
+
+from .constants import (
+    DATASET_ITEM_ID_KEY,
+    GROUND_TRUTH_ANNOTATION_ID_KEY,
+    GROUND_TRUTH_ANNOTATION_LABEL_KEY,
+    IOU_KEY,
+    MODEL_PREDICTION_ID_KEY,
+    MODEL_PREDICTION_LABEL_KEY,
+    MODEL_RUN_ID_KEY,
+    TRUE_POSITIVE_KEY,
+)
+
+
+class ConfusionCategory(Enum):
+    TRUE_POSITIVE = "true_positive"
+    FALSE_POSITIVE = "false_positive"
+    FALSE_NEGATIVE = "false_negative"
+
+
+def infer_confusion_category(
+    true_positive: bool,
+    ground_truth_annotation_label: str,
+    model_prediction_label: str,
+):
+    confusion_category = ConfusionCategory.FALSE_NEGATIVE
+
+    if (
+        true_positive
+        or model_prediction_label == ground_truth_annotation_label
+    ):
+        confusion_category = ConfusionCategory.TRUE_POSITIVE
+    elif model_prediction_label is not None:
+        confusion_category = ConfusionCategory.FALSE_POSITIVE
+
+    return confusion_category
+
+
+@dataclass
+class EvaluationMatch:
+    """
+    EvaluationMatch is a result from a model run evaluation. It can represent a true positive, false positive,
+    or false negative.
+
+    The matching only matches the strongest prediction for each annotation, so if there are multiple predictions
+    that overlap a single annotation, only the one with the highest overlap metric will be matched.
+
+    The model prediction label and the ground truth annotation label can differ for true positives if an
+    allowed_label_mapping is configured for the model run.
+
+    NOTE: There is no iou thresholding applied to these matches, so it is possible to have a true positive with a low
+    iou score. If manually rejecting matches, remember that a rejected match produces both a false positive and a false
+    negative; otherwise you'll skew your aggregates.
+
+    Attributes:
+        model_run_id (str): The ID of the model run that produced this match.
+        model_prediction_id (str): The ID of the model prediction that was matched. None if the match was a false negative.
+        ground_truth_annotation_id (str): The ID of the ground truth annotation that was matched. None if the match was a false positive.
+        iou (float): The intersection over union score of the match.
+        dataset_item_id (str): The ID of the dataset item that was matched.
+        confusion_category (ConfusionCategory): The confusion category of the match.
+        model_prediction_label (str): The label of the model prediction that was matched. None if the match was a false negative.
+        ground_truth_annotation_label (str): The label of the ground truth annotation that was matched. None if the match was a false positive.
+    """
+
+    model_run_id: str
+    model_prediction_id: Optional[str]  # field is nullable
+    ground_truth_annotation_id: Optional[str]  # field is nullable
+    iou: float
+    dataset_item_id: str
+    confusion_category: ConfusionCategory
+    model_prediction_label: Optional[str]  # field is nullable
+    ground_truth_annotation_label: Optional[str]  # field is nullable
+
+    @classmethod
+    def from_json(cls, payload: dict):
+        is_true_positive = payload.get(TRUE_POSITIVE_KEY, False)
+        model_prediction_label = payload.get(MODEL_PREDICTION_LABEL_KEY, None)
+        ground_truth_annotation_label = payload.get(
+            GROUND_TRUTH_ANNOTATION_LABEL_KEY, None
+        )
+
+        confusion_category = infer_confusion_category(
+            true_positive=is_true_positive,
+            ground_truth_annotation_label=ground_truth_annotation_label,
+            model_prediction_label=model_prediction_label,
+        )
+
+        return cls(
+            model_run_id=payload[MODEL_RUN_ID_KEY],
+            model_prediction_id=payload.get(MODEL_PREDICTION_ID_KEY, None),
+            ground_truth_annotation_id=payload.get(
+                GROUND_TRUTH_ANNOTATION_ID_KEY, None
+            ),
+            iou=payload[IOU_KEY],
+            dataset_item_id=payload[DATASET_ITEM_ID_KEY],
+            confusion_category=confusion_category,
+            model_prediction_label=model_prediction_label,
+            ground_truth_annotation_label=ground_truth_annotation_label,
+        )
```
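To make the parsing and classification logic above concrete, a small sketch exercising `infer_confusion_category` and `EvaluationMatch.from_json` with a hand-built payload (the payload values are illustrative, not a captured server response):

```python
from nucleus.evaluation_match import (
    ConfusionCategory,
    EvaluationMatch,
    infer_confusion_category,
)

# A prediction whose label disagrees with the ground truth, and which the
# matcher did not flag as a true positive, classifies as a false positive.
assert infer_confusion_category(
    true_positive=False,
    ground_truth_annotation_label="car",
    model_prediction_label="truck",
) == ConfusionCategory.FALSE_POSITIVE

# Illustrative payload shaped like the keys from_json reads via constants;
# the IDs and values are made up.
payload = {
    "model_run_id": "run_id",
    "model_prediction_id": "pred_id",
    "ground_truth_annotation_id": "ann_id",
    "iou": 0.42,  # no IOU threshold is applied, per the class docstring
    "dataset_item_id": "item_id",
    "true_positive": True,
    "model_prediction_label": "car",
    "ground_truth_annotation_label": "car",
}
match = EvaluationMatch.from_json(payload)
assert match.confusion_category == ConfusionCategory.TRUE_POSITIVE
assert match.iou == 0.42
```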

pyproject.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -25,7 +25,7 @@ ignore = ["E501", "E741", "E731", "F401"] # Easy ignore for getting it running
 
 [tool.poetry]
 name = "scale-nucleus"
-version = "0.16.3"
+version = "0.16.4"
 description = "The official Python client library for Nucleus, the Data Platform for AI"
 license = "MIT"
 authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]
```
