-
Notifications
You must be signed in to change notification settings - Fork 5
Open
Description
def _compute_element_wise(self, y_pred, y_true):
    """Compute the configured detection metric for every sample in a batch.

    For each (ground truth, prediction) pair, predictions whose score is
    not strictly greater than ``self.score_threshold`` are discarded, and
    the remaining boxes are scored against the ground-truth boxes with
    ``self._accuracy`` or ``self._recall``, depending on ``self.metric``.

    Args:
        y_pred: Iterable of per-sample prediction mappings. Each must
            provide ``self.geometry_name`` (boxes) and ``"scores"``.
        y_true: Iterable of per-sample ground-truth mappings. Each must
            provide ``self.geometry_name``.

    Returns:
        A ``torch.Tensor`` with one metric value per sample.

    Raises:
        ValueError: If ``self.metric`` is neither ``"accuracy"`` nor
            ``"recall"``.
    """
    batch_results = []
    for gt, target in zip(y_true, y_pred):
        target_boxes = target[self.geometry_name]
        target_scores = target["scores"]
        gt_boxes = gt[self.geometry_name]

        # A lone prediction can arrive as a flat 1-D box (e.g. shape [4])
        # alongside a 1-element score tensor. Boolean-mask indexing would
        # then try to mask the coordinate axis and raise IndexError, so
        # promote the flat box to a single row first.
        # NOTE(review): assumes boxes are otherwise laid out as
        # (num_boxes, coords) -- confirm against the producer.
        if target_boxes.ndim == 1 and target_boxes.shape[0] > 0:
            target_boxes = target_boxes[None]

        # Flatten the mask so 0-d or column-shaped scores still select
        # along the first (box) axis.
        keep = (target_scores > self.score_threshold).reshape(-1)
        pred_boxes = target_boxes[keep]

        if self.metric == "accuracy":
            det_accuracy = self._accuracy(gt_boxes, pred_boxes,
                                          self.iou_threshold)
        elif self.metric == "recall":
            det_accuracy = self._recall(gt_boxes, pred_boxes,
                                        self.iou_threshold)
        else:
            # Without this branch ``det_accuracy`` would be unbound and the
            # failure would surface as an opaque UnboundLocalError.
            raise ValueError(
                f"Unsupported metric {self.metric!r}; "
                "expected 'accuracy' or 'recall'."
            )
        batch_results.append(det_accuracy)
    return torch.tensor(batch_results)
gt_boxes
tensor([[ 79.7891, 396.1453, 91.7831, 411.0619]])
target_boxes
tensor([0., 0., 0., 0.])
yields
pred_boxes = target_boxes[target_scores > self.score_threshold]
IndexError: The shape of the mask [1] at index 0 does not match the shape of the indexed tensor [4] at index 0
Metadata
Metadata
Assignees
Labels
No labels