Skip to content

Commit 8eac828

Browse files
authored
[Validate] Introduce better filtering errors (#321)
* Raise error if everything is filtered
* Fix error reporting
* Bump and changelog
1 parent cf51ddb commit 8eac828

File tree

9 files changed

+283
-60
lines changed

9 files changed

+283
-60
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
fail_fast: true
1+
fail_fast: false
22
repos:
33
- repo: local
44
hooks:

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [0.14.2](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.2) - 2022-06-21
9+
10+
### Fixed
11+
- Better error reporting when everything is filtered out by a filter statement in a Validate evaluation function
12+
813
## [0.14.1](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.1) - 2022-06-20
914

1015
### Fixed

nucleus/annotation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,9 @@ def add_annotations(self, annotations: List[Annotation]):
939939
), f"Unexpected annotation type: {type(annotation)}"
940940
self.segmentation_annotations.append(annotation)
941941

942+
def items(self):
    """Return a view of ``(attribute name, value)`` pairs of this instance.

    The pairs come straight from the instance ``__dict__``, so the view
    reflects whatever attributes are currently set on the object.
    """
    return vars(self).items()
944+
942945
def __len__(self):
943946
return (
944947
len(self.box_annotations)

nucleus/metrics/base.py

Lines changed: 34 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
from typing import Iterable, List, Optional, Union
55

66
from nucleus.annotation import AnnotationList
7+
from nucleus.metrics.errors import EverythingFilteredError
78
from nucleus.metrics.filtering import (
89
ListOfAndFilters,
910
ListOfOrAndFilters,
10-
apply_filters,
11+
compose_helpful_filtering_error,
12+
filter_annotation_list,
13+
filter_prediction_list,
1114
)
1215
from nucleus.prediction import PredictionList
1316

@@ -133,64 +136,16 @@ def call_metric(
133136
def __call__(
    self, annotations: AnnotationList, predictions: PredictionList
) -> MetricResult:
    """Filter the inputs, verify something is left, then compute the metric.

    Args:
        annotations: Ground-truth annotations for the item being evaluated.
        predictions: Model predictions for the item being evaluated.

    Returns:
        The result of ``call_metric`` evaluated on the filtered annotations
        and predictions.

    Raises:
        EverythingFilteredError: If filtering leaves no annotations or no
            predictions.
    """
    filtered_anns = filter_annotation_list(
        annotations, self.annotation_filters
    )
    filtered_preds = filter_prediction_list(
        predictions, self.prediction_filters
    )
    # The unfiltered inputs are passed along only so the error message can
    # show what was available before the filters were applied.
    self._raise_if_everything_filtered(
        annotations, filtered_anns, predictions, filtered_preds
    )
    # Evaluate on the *filtered* lists; passing the originals here would
    # silently ignore the configured filters.
    return self.call_metric(filtered_anns, filtered_preds)
194149

195150
@abstractmethod
196151
def aggregate_score(self, results: List[MetricResult]) -> ScalarResult:
@@ -215,3 +170,26 @@ def aggregate_score(self, results: List[MetricResult]) -> ScalarResult:
215170
return ScalarResult(r2_score)
216171
217172
"""
173+
174+
def _raise_if_everything_filtered(
    self,
    annotations: AnnotationList,
    filtered_annotations: AnnotationList,
    predictions: PredictionList,
    filtered_predictions: PredictionList,
):
    """Raise a descriptive error when a filtered list ended up empty.

    Args:
        annotations: Annotations before filtering (used for the report).
        filtered_annotations: Annotations after filtering.
        predictions: Predictions before filtering (used for the report).
        filtered_predictions: Predictions after filtering.

    Raises:
        EverythingFilteredError: If the filtered annotations and/or the
            filtered predictions have length zero. The message joins the
            report lines for every empty side with newlines.
    """
    report_lines = []
    # Annotations are reported before predictions. The filter attribute is
    # looked up lazily, only for a side that actually ended up empty.
    for remaining, original, filters_attr in (
        (filtered_annotations, annotations, "annotation_filters"),
        (filtered_predictions, predictions, "prediction_filters"),
    ):
        if len(remaining) != 0:
            continue
        report_lines += compose_helpful_filtering_error(
            original, getattr(self, filters_attr)
        )
    if report_lines:
        raise EverythingFilteredError("\n".join(report_lines))

nucleus/metrics/errors.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,7 @@ def __init__(
55
):
66
self.message = message
77
super().__init__(self.message)
8+
9+
10+
class EverythingFilteredError(Exception):
    """Raised when filtering leaves no annotations or no predictions."""

nucleus/metrics/filtering.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import enum
23
import functools
34
import logging
@@ -7,13 +8,18 @@
78
Iterable,
89
List,
910
NamedTuple,
11+
Optional,
1012
Sequence,
1113
Set,
1214
Tuple,
1315
Union,
1416
)
1517

18+
from rich.console import Console
19+
from rich.table import Table
20+
1621
from nucleus.annotation import (
22+
AnnotationList,
1723
BoxAnnotation,
1824
CategoryAnnotation,
1925
CuboidAnnotation,
@@ -29,6 +35,7 @@
2935
CuboidPrediction,
3036
LinePrediction,
3137
PolygonPrediction,
38+
PredictionList,
3239
SegmentationPrediction,
3340
)
3441

@@ -568,3 +575,147 @@ def ensureDNFFilters(filters) -> OrAndDNFFilters:
568575
formatted_filter.append(and_chain)
569576
filters = formatted_filter
570577
return filters
578+
579+
580+
def pretty_format_filters_with_or_and(
    filters: Optional[Union[ListOfOrAndFilters, ListOfAndFilters]]
):
    """Render filters as a readable ``... and ... or ...`` expression.

    Each filter is printed as a constructor-like call, e.g.
    ``FieldFilter("label", "=", "car")``. Filters inside one branch are
    joined with ``and``; branches are joined with ``or``. A branch with more
    than one statement is parenthesised when there are several branches.

    Args:
        filters: Filters in either accepted list form, or ``None``.

    Returns:
        The formatted expression, or ``"No filters applied!"`` when
        ``filters`` is ``None`` or empty.

    Raises:
        RuntimeError: If a filter has a type that is not handled here.
    """
    # An empty filter collection filters nothing, exactly like ``None``;
    # the filtering functions in this module treat both the same way.
    if filters is None or len(filters) == 0:
        return "No filters applied!"

    # Kept as pairs compared with ``==`` (not a dict lookup) so matching
    # behaves exactly like an if/elif chain on ``and_branch.type``.
    class_name_by_type = (
        (FilterType.FIELD, "FieldFilter"),
        (FilterType.METADATA, "MetadataFilter"),
        (FilterType.SEGMENT_FIELD, "SegmentFieldFilter"),
        (FilterType.SEGMENT_METADATA, "SegmentMetadataFilter"),
    )

    dnf_filters = ensureDNFFilters(filters)
    or_branches = []
    for or_branch in dnf_filters:
        and_statements = []
        for and_branch in or_branch:
            class_name = next(
                (
                    name
                    for filter_type, name in class_name_by_type
                    if and_branch.type == filter_type
                ),
                None,
            )
            if class_name is None:
                raise RuntimeError(
                    f"Un-handled filter type: {and_branch.type}"
                )
            # The operator may be a FilterOp member or already a plain value.
            op = (
                and_branch.op.value
                if isinstance(and_branch.op, FilterOp)
                else and_branch.op
            )
            # Strings are quoted; other values use their str() form with
            # single quotes swapped for double quotes.
            value_formatted = (
                f'"{and_branch.value}"'
                if isinstance(and_branch.value, str)
                else f"{and_branch.value}".replace("'", '"')
            )
            and_statements.append(
                f'{class_name}("{and_branch.key}", "{op}", {value_formatted})'
            )
        or_branches.append(and_statements)

    and_to_join = []
    for and_statements in or_branches:
        joined_and = " and ".join(and_statements)
        # Parentheses are only needed to disambiguate when several
        # or-branches exist and this branch is itself a conjunction.
        if len(or_branches) > 1 and len(and_statements) > 1:
            joined_and = "(" + joined_and + ")"
        and_to_join.append(joined_and)

    return " or ".join(and_to_join)
628+
629+
630+
def compose_helpful_filtering_error(
    ann_or_pred_list: Union[AnnotationList, PredictionList], filters
) -> List[str]:
    """Build the message lines explaining that filtering removed everything.

    Args:
        ann_or_pred_list: The *unfiltered* annotation or prediction list.
        filters: The filters that were applied to it.

    Returns:
        Message lines: a header, the formatted filter expression, an empty
        line, and a rendered table with one row per non-empty item type
        (type name, item count, labels found).
    """
    prefix = (
        "Annotations"
        if isinstance(ann_or_pred_list, AnnotationList)
        else "Predictions"
    )
    msg = [
        f"{prefix}: All items filtered out by:",
        f" {pretty_format_filters_with_or_and(filters)}",
        "",
    ]
    table = Table(
        "Type",
        "Count",
        "Labels",
        title=f"Original {prefix}",
        title_justify="left",
    )
    for ann_or_pred_type, items in ann_or_pred_list.items():
        if not items:
            # Empty groups get no row.
            continue
        labels = set()
        for item in items:
            if isinstance(
                item, (SegmentationAnnotation, SegmentationPrediction)
            ):
                # Segmentation items carry their labels on the contained
                # segments rather than on the item itself.
                labels.update(s.label for s in item.annotations)
            elif hasattr(item, "label"):
                labels.add(item.label)
            else:
                # NOTE(review): multi-category items presumably expose a
                # `labels` list instead of `label` -- confirm against
                # nucleus.annotation. This avoids an AttributeError while
                # composing the error, which would mask the real problem.
                labels.update(getattr(item, "labels", ()))
        # Sort so the message is deterministic (set order is not).
        table.add_row(
            ann_or_pred_type,
            str(len(items)),
            str(sorted(labels, key=str)),
        )
    # Render the table to a string instead of printing it.
    console = Console()
    with console.capture() as capture:
        console.print(table)
    msg.append(capture.get())
    return msg
665+
666+
667+
def filter_annotation_list(
    annotations: AnnotationList, annotation_filters
) -> AnnotationList:
    """Return a deep copy of ``annotations`` with the filters applied.

    Args:
        annotations: The annotation list to filter. The per-type lists are
            reassigned on a deep copy, not on this object.
        annotation_filters: Filters passed to ``apply_filters`` for every
            annotation type. ``None`` or an empty collection means no
            filtering.

    Returns:
        The deep copy, with each per-type list replaced by its filtered
        version when filters are given.
    """
    filtered = copy.deepcopy(annotations)
    has_filters = (
        annotation_filters is not None and len(annotation_filters) != 0
    )
    if has_filters:
        for field_name in (
            "box_annotations",
            "line_annotations",
            "polygon_annotations",
            "cuboid_annotations",
            "category_annotations",
            "multi_category_annotations",
            "segmentation_annotations",
        ):
            kept = apply_filters(
                getattr(filtered, field_name), annotation_filters
            )
            setattr(filtered, field_name, kept)
    return filtered
695+
696+
697+
def filter_prediction_list(
    predictions: PredictionList, prediction_filters
) -> PredictionList:
    """Return a deep copy of ``predictions`` with the filters applied.

    Args:
        predictions: The prediction list to filter. The per-type lists are
            reassigned on a deep copy, not on this object.
        prediction_filters: Filters passed to ``apply_filters`` for every
            prediction type. ``None`` or an empty collection means no
            filtering.

    Returns:
        The deep copy, with each per-type list replaced by its filtered
        version when filters are given.
    """
    filtered = copy.deepcopy(predictions)
    has_filters = (
        prediction_filters is not None and len(prediction_filters) != 0
    )
    if has_filters:
        for field_name in (
            "box_predictions",
            "line_predictions",
            "polygon_predictions",
            "cuboid_predictions",
            "category_predictions",
            "segmentation_predictions",
        ):
            kept = apply_filters(
                getattr(filtered, field_name), prediction_filters
            )
            setattr(filtered, field_name, kept)
    return filtered

nucleus/prediction.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,9 @@ class PredictionList:
600600
default_factory=list
601601
)
602602

603+
def items(self):
    """Return a view of ``(attribute name, value)`` pairs of this instance.

    The pairs come straight from the instance ``__dict__``, so the view
    reflects whatever attributes are currently set on the object.
    """
    return vars(self).items()
605+
603606
def add_predictions(self, predictions: List[Prediction]):
604607
for prediction in predictions:
605608
if isinstance(prediction, BoxPrediction):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ exclude = '''
2121

2222
[tool.poetry]
2323
name = "scale-nucleus"
24-
version = "0.14.1"
24+
version = "0.14.2"
2525
description = "The official Python client library for Nucleus, the Data Platform for AI"
2626
license = "MIT"
2727
authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]

0 commit comments

Comments
 (0)