 
 import numpy as np
 
-from ..iou.calculation import _get_mask_pairs, _get_vector_pairs, miou
+from ..iou.calculation import _get_mask_pairs, _get_vector_pairs, _get_ner_pairs, miou
 from ...annotation_types import (ObjectAnnotation, ClassificationAnnotation,
-                                 Mask, Geometry, Checklist, Radio,
+                                 Mask, Geometry, Checklist, Radio, TextEntity,
                                  ScalarMetricValue, ConfusionMatrixMetricValue)
 from ..group import (get_feature_pairs, get_identifying_key, has_no_annotations,
                      has_no_matching_annotations)
@@ -68,6 +68,8 @@ def feature_confusion_matrix(
     elif isinstance(predictions[0].value, Geometry):
         return vector_confusion_matrix(ground_truths, predictions,
                                        include_subclasses, iou)
+    elif isinstance(predictions[0].value, TextEntity):
+        return ner_confusion_matrix(ground_truths, predictions, include_subclasses, iou)
     elif isinstance(predictions[0], ClassificationAnnotation):
         return classification_confusion_matrix(ground_truths, predictions)
     else:
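For context on the TextEntity branch above: predictions whose value is a TextEntity are routed to ner_confusion_matrix (added below), and the pairing it delegates to the newly imported _get_ner_pairs presumably scores ground-truth/prediction candidates by the overlap of their character spans. The snippet below is only an illustrative sketch of that kind of span IoU under a half-open [start, end) convention; span_iou is a hypothetical helper, not the actual miou/_get_ner_pairs implementation from ..iou.calculation.

def span_iou(start_a: int, end_a: int, start_b: int, end_b: int) -> float:
    # Hypothetical illustration, not the library's implementation:
    # IoU of two half-open character ranges [start, end).
    intersection = max(0, min(end_a, end_b) - max(start_a, start_b))
    union = (end_a - start_a) + (end_b - start_b) - intersection
    return intersection / union if union else 0.0

# Spans [0, 10) and [5, 15) share 5 of the 15 characters they cover together.
assert abs(span_iou(0, 10, 5, 15) - 5 / 15) < 1e-9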
@@ -288,3 +290,25 @@ def mask_confusion_matrix(ground_truths: List[ObjectAnnotation],
     fn_mask = (prediction_np == 0) & (ground_truth_np == 1)
     tn_mask = (prediction_np == 0) & (ground_truth_np == 0)
     return [np.sum(tp_mask), np.sum(fp_mask), np.sum(fn_mask), np.sum(tn_mask)]
+
+
+def ner_confusion_matrix(ground_truths: List[ObjectAnnotation],
+                         predictions: List[ObjectAnnotation],
+                         include_subclasses: bool,
+                         iou: float) -> Optional[ConfusionMatrixMetricValue]:
+    """Computes confusion matrix metric between two lists of TextEntity objects
+
+    TODO: work on include_subclasses logic
+
+    Args:
+        ground_truths: List of ground truth text entity annotations
+        predictions: List of predicted text entity annotations
+    Returns:
+        confusion matrix as a list: [TP, FP, TN, FN]
+    """
+    if has_no_matching_annotations(ground_truths, predictions):
+        return [0, int(len(predictions) > 0), 0, int(len(ground_truths) > 0)]
+    elif has_no_annotations(ground_truths, predictions):
+        return None
+    pairs = _get_ner_pairs(ground_truths, predictions)
+    return object_pair_confusion_matrix(pairs, include_subclasses, iou)
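Taken together, the diff wires NER annotations through the same confusion-matrix machinery as geometries. A minimal usage sketch follows; the absolute import paths are inferred from the relative imports above, and the ObjectAnnotation/TextEntity constructor fields (name, start, end) are assumptions about the SDK's annotation_types, so treat this as an outline rather than verified example code.

from labelbox.data.annotation_types import ObjectAnnotation, TextEntity
from labelbox.data.metrics.confusion_matrix.calculation import ner_confusion_matrix

# One ground-truth entity and one overlapping prediction with the same feature name.
ground_truths = [ObjectAnnotation(name="person", value=TextEntity(start=0, end=5))]
predictions = [ObjectAnnotation(name="person", value=TextEntity(start=0, end=4))]

# Returns [TP, FP, TN, FN]; the matched pair counts as a TP only if its IoU
# clears the supplied threshold.
print(ner_confusion_matrix(ground_truths, predictions,
                           include_subclasses=False, iou=0.5))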