Skip to content

Commit 1e51265

Browse files
author
Claire Pajot
committed
First pass
1 parent fd633b5 commit 1e51265

File tree

6 files changed

+216
-17
lines changed

6 files changed

+216
-17
lines changed

nucleus/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
CuboidPrediction,
8888
PolygonPrediction,
8989
SegmentationPrediction,
90+
CategoryPrediction,
9091
)
9192
from .scene import Frame, LidarScene
9293
from .slice import Slice
@@ -742,6 +743,7 @@ def predict(
742743
PolygonPrediction,
743744
CuboidPrediction,
744745
SegmentationPrediction,
746+
CategoryPrediction,
745747
]
746748
],
747749
model_run_id: Optional[str] = None,
@@ -752,7 +754,7 @@ def predict(
752754
):
753755
"""
754756
Uploads model outputs as predictions for a model_run. Returns info about the upload.
755-
:param annotations: List[Union[BoxPrediction, PolygonPrediction, CuboidPrediction, SegmentationPrediction]],
757+
:param annotations: List[Union[BoxPrediction, PolygonPrediction, CuboidPrediction, SegmentationPrediction, CategoryPrediction]],
756758
:param update: bool
757759
:return:
758760
{
@@ -914,7 +916,7 @@ def predictions_ref_id(self, model_run_id: str, ref_id: str):
914916
:param reference_id: reference_id of a dataset item.
915917
:return:
916918
{
917-
"annotations": List[Union[BoxPrediction, PolygonPrediction, CuboidPrediction, SegmentationPrediction]],
919+
"annotations": List[Union[BoxPrediction, PolygonPrediction, CuboidPrediction, SegmentationPrediction, CategoryPrediction]],
918920
}
919921
"""
920922
return self.make_request(
@@ -939,7 +941,7 @@ def predictions_iloc(self, model_run_id: str, i: int):
939941
:param i: absolute number of Dataset Item for a dataset corresponding to the model run.
940942
:return:
941943
{
942-
"annotations": List[Union[BoxPrediction, PolygonPrediction, CuboidPrediction, SegmentationPrediction]],
944+
"annotations": List[Union[BoxPrediction, PolygonPrediction, CuboidPrediction, SegmentationPrediction, CategoryPrediction]],
943945
}
944946
"""
945947
return self.make_request(

nucleus/payload_constructor.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
CuboidPrediction,
1515
PolygonPrediction,
1616
SegmentationPrediction,
17+
CategoryPrediction,
1718
)
1819
from .constants import (
1920
ANNOTATION_UPDATE_KEY,
@@ -91,7 +92,12 @@ def construct_segmentation_payload(
9192

9293
def construct_box_predictions_payload(
9394
box_predictions: List[
94-
Union[BoxPrediction, PolygonPrediction, CuboidPrediction]
95+
Union[
96+
BoxPrediction,
97+
PolygonPrediction,
98+
CuboidPrediction,
99+
CategoryPrediction,
100+
]
95101
],
96102
update: bool,
97103
) -> dict:

nucleus/prediction.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Dict, Optional, List
22
from .annotation import (
33
BoxAnnotation,
4+
CategoryAnnotation,
45
Point,
56
PolygonAnnotation,
67
Segment,
@@ -13,10 +14,12 @@
1314
BOX_TYPE,
1415
CUBOID_TYPE,
1516
POLYGON_TYPE,
17+
CATEGORY_TYPE,
1618
REFERENCE_ID_KEY,
1719
METADATA_KEY,
1820
GEOMETRY_KEY,
1921
LABEL_KEY,
22+
TAXONOMY_NAME_KEY,
2023
TYPE_KEY,
2124
X_KEY,
2225
Y_KEY,
@@ -40,6 +43,8 @@ def from_json(payload: dict):
4043
return PolygonPrediction.from_json(payload)
4144
elif payload.get(TYPE_KEY, None) == CUBOID_TYPE:
4245
return CuboidPrediction.from_json(payload)
46+
elif payload.get(TYPE_KEY, None) == CATEGORY_TYPE:
47+
return CategoryPrediction.from_json(payload)
4348
else:
4449
return SegmentationPrediction.from_json(payload)
4550

@@ -207,3 +212,43 @@ def from_json(cls, payload: dict):
207212
metadata=payload.get(METADATA_KEY, {}),
208213
class_pdf=payload.get(CLASS_PDF_KEY, None),
209214
)
215+
216+
217+
class CategoryPrediction(CategoryAnnotation):
218+
def __init__(
219+
self,
220+
label: str,
221+
taxonomy_name: str,
222+
reference_id: str,
223+
confidence: Optional[float] = None,
224+
metadata: Optional[Dict] = None,
225+
class_pdf: Optional[Dict] = None,
226+
):
227+
super().__init__(
228+
label=label,
229+
taxonomy_name=taxonomy_name,
230+
reference_id=reference_id,
231+
metadata=metadata,
232+
)
233+
self.confidence = confidence
234+
self.class_pdf = class_pdf
235+
236+
def to_payload(self) -> dict:
237+
payload = super().to_payload()
238+
if self.confidence is not None:
239+
payload[CONFIDENCE_KEY] = self.confidence
240+
if self.class_pdf is not None:
241+
payload[CLASS_PDF_KEY] = self.class_pdf
242+
243+
return payload
244+
245+
@classmethod
246+
def from_json(cls, payload: dict):
247+
return cls(
248+
label=payload.get(LABEL_KEY, 0),
249+
taxonomy_name=payload.get(TAXONOMY_NAME_KEY, None),
250+
reference_id=payload[REFERENCE_ID_KEY],
251+
confidence=payload.get(CONFIDENCE_KEY, None),
252+
metadata=payload.get(METADATA_KEY, {}),
253+
class_pdf=payload.get(CLASS_PDF_KEY, None),
254+
)

nucleus/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from .dataset_item import DatasetItem
3535
from .prediction import (
3636
BoxPrediction,
37+
CategoryPrediction,
3738
CuboidPrediction,
3839
PolygonPrediction,
3940
SegmentationPrediction,
@@ -50,6 +51,7 @@ def format_prediction_response(
5051
BoxPrediction,
5152
PolygonPrediction,
5253
CuboidPrediction,
54+
CategoryPrediction,
5355
SegmentationPrediction,
5456
]
5557
],
@@ -65,12 +67,14 @@ def format_prediction_response(
6567
Type[BoxPrediction],
6668
Type[PolygonPrediction],
6769
Type[CuboidPrediction],
70+
Type[CategoryPrediction],
6871
Type[SegmentationPrediction],
6972
],
7073
] = {
7174
BOX_TYPE: BoxPrediction,
7275
POLYGON_TYPE: PolygonPrediction,
7376
CUBOID_TYPE: CuboidPrediction,
77+
CATEGORY_TYPE: CategoryPrediction,
7478
SEGMENTATION_TYPE: SegmentationPrediction,
7579
}
7680
for type_key in annotation_payload:

tests/helpers.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,11 @@ def reference_id_from_url(url):
206206
for polygon_annotation in TEST_POLYGON_ANNOTATIONS
207207
}
208208

209+
TEST_CATEGORY_MODEL_PDF = {
210+
category_annotation["label"]: 1 / len(TEST_CATEGORY_ANNOTATIONS)
211+
for category_annotation in TEST_CATEGORY_ANNOTATIONS
212+
}
213+
209214
TEST_BOX_PREDICTIONS = [
210215
{
211216
**TEST_BOX_ANNOTATIONS[i],
@@ -234,6 +239,20 @@ def reference_id_from_url(url):
234239
for i in range(len(TEST_POLYGON_ANNOTATIONS))
235240
]
236241

242+
TEST_CATEGORY_PREDICTIONS = [
243+
{
244+
**TEST_CATEGORY_ANNOTATIONS[i],
245+
"confidence": 0.10 * i,
246+
"class_pdf": TEST_CATEGORY_MODEL_PDF,
247+
}
248+
if i != 0
249+
else {
250+
**TEST_CATEGORY_ANNOTATIONS[i],
251+
"confidence": 0.10 * i,
252+
}
253+
for i in range(len(TEST_CATEGORY_ANNOTATIONS))
254+
]
255+
237256
TEST_INDEX_EMBEDDINGS_FILE = "https://raw.githubusercontent.com/scaleapi/nucleus-python-client/master/tests/testdata/pytest_embeddings_payload.json"
238257

239258

@@ -339,3 +358,12 @@ def assert_polygon_prediction_matches_dict(
339358
prediction_instance, prediction_dict
340359
)
341360
assert prediction_instance.confidence == prediction_dict["confidence"]
361+
362+
363+
def assert_category_prediction_matches_dict(
364+
prediction_instance, prediction_dict
365+
):
366+
assert_category_annotation_matches_dict(
367+
prediction_instance, prediction_dict
368+
)
369+
assert prediction_instance.confidence == prediction_dict["confidence"]

0 commit comments

Comments
 (0)