Skip to content

Commit d406e06

Browse files
committed
first version of including NER confusion matrix
1 parent 6804637 commit d406e06

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

labelbox/data/metrics/confusion_matrix/calculation.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import numpy as np
44

5-
from ..iou.calculation import _get_mask_pairs, _get_vector_pairs, miou
5+
from ..iou.calculation import _get_mask_pairs, _get_vector_pairs, _get_ner_pairs, miou
66
from ...annotation_types import (ObjectAnnotation, ClassificationAnnotation,
7-
Mask, Geometry, Checklist, Radio,
7+
Mask, Geometry, Checklist, Radio, TextEntity,
88
ScalarMetricValue, ConfusionMatrixMetricValue)
99
from ..group import (get_feature_pairs, get_identifying_key, has_no_annotations,
1010
has_no_matching_annotations)
@@ -68,6 +68,8 @@ def feature_confusion_matrix(
6868
elif isinstance(predictions[0].value, Geometry):
6969
return vector_confusion_matrix(ground_truths, predictions,
7070
include_subclasses, iou)
71+
elif isinstance(predictions[0].value, TextEntity):
72+
pass #TODO
7173
elif isinstance(predictions[0], ClassificationAnnotation):
7274
return classification_confusion_matrix(ground_truths, predictions)
7375
else:
@@ -288,3 +290,25 @@ def mask_confusion_matrix(ground_truths: List[ObjectAnnotation],
288290
fn_mask = (prediction_np == 0) & (ground_truth_np == 1)
289291
tn_mask = prediction_np == ground_truth_np == 0
290292
return [np.sum(tp_mask), np.sum(fp_mask), np.sum(fn_mask), np.sum(tn_mask)]
293+
294+
295+
def ner_confusion_matrix(ground_truths: List[ObjectAnnotation],
296+
predictions: list[ObjectAnnotation],
297+
include_subclasses: bool,
298+
iou: float) -> Optional[ConfusionMatrixMetricValue]:
299+
"""Computes confusion matrix metric between two lists of TextEntity objects
300+
301+
TODO: work on include_subclasses logic
302+
303+
Args:
304+
ground_truths: List of ground truth mask annotations
305+
predictions: List of prediction mask annotations
306+
Returns:
307+
confusion matrix as a list: [TP,FP,TN,FN]
308+
"""
309+
if has_no_matching_annotations(ground_truths, predictions):
310+
return [0, int(len(predictions) > 0), 0, int(len(ground_truths) > 0)]
311+
elif has_no_annotations(ground_truths, predictions):
312+
return None
313+
pairs = _get_ner_pairs(ground_truths, predictions)
314+
return object_pair_confusion_matrix(pairs, include_subclasses, iou)

labelbox/data/metrics/iou/calculation.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from shapely.geometry import Polygon
55
import numpy as np
66

7+
from labelbox.data.annotation_types.ner import TextEntity
8+
79
from ..group import get_feature_pairs, get_identifying_key, has_no_annotations, has_no_matching_annotations
810
from ...annotation_types import (ObjectAnnotation, ClassificationAnnotation,
911
Mask, Geometry, Point, Line, Checklist, Text,
@@ -269,3 +271,25 @@ def _ensure_valid_poly(poly):
269271
def _mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> ScalarMetricValue:
270272
"""Computes iou between two binary segmentation masks."""
271273
return np.sum(mask1 & mask2) / np.sum(mask1 | mask2)
274+
275+
276+
def _get_ner_pairs(
277+
ground_truths: List[ObjectAnnotation], predictions: List[ObjectAnnotation]
278+
) -> List[Tuple[ObjectAnnotation, ObjectAnnotation, ScalarMetricValue]]:
279+
"""Get iou score for all possible pairs of ground truths and predictions"""
280+
pairs = []
281+
for ground_truth, prediction in product(ground_truths, predictions):
282+
score = _ner_iou(ground_truth.value, prediction.value)
283+
pairs.append((ground_truth, prediction, score))
284+
# print(ground_truth.value.start, ground_truth.value.end,
285+
# prediction.value.start, prediction.value.end)
286+
return pairs
287+
288+
289+
def _ner_iou(ner1: TextEntity, ner2: TextEntity):
290+
"""Computes iou between two text entity annotations"""
291+
intersection_start, intersection_end = max(ner1.start, ner2.start), min(
292+
ner1.end, ner2.end)
293+
union_start, union_end = min(ner1.start,
294+
ner2.start), max(ner1.end, ner2.end)
295+
return (intersection_end - intersection_start) / (union_end - union_start)

0 commit comments

Comments
 (0)