Skip to content

fix(models/utils.py): refactor nms() implmentation, same with torchvi… #192

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 48 additions & 44 deletions models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,50 +127,54 @@ def batched_nms(boxes: ndarray,
return keep_boxes, keep_scores, keep_labels


def nms(boxes: ndarray,
scores: ndarray,
iou_thres: float = 0.65,
conf_thres: float = 0.25):
labels = np.argmax(scores, axis=-1)
scores = np.max(scores, axis=-1)

cand = scores > conf_thres
boxes = boxes[cand]
scores = scores[cand]
labels = labels[cand]

keep_boxes = []
keep_scores = []
keep_labels = []

idxs = scores.argsort()
while idxs.size > 0:
max_score_index = idxs[-1]
max_box = boxes[max_score_index:max_score_index + 1]
max_score = scores[max_score_index:max_score_index + 1]
max_label = np.array([labels[max_score_index]], dtype=np.int32)
keep_boxes.append(max_box)
keep_scores.append(max_score)
keep_labels.append(max_label)
if idxs.size == 1:
break
idxs = idxs[:-1]
other_boxes = boxes[idxs]
ious = bbox_iou(max_box, other_boxes)
iou_mask = ious < iou_thres
idxs = idxs[iou_mask]

if len(keep_boxes) == 0:
keep_boxes = np.empty((0, 4), dtype=np.float32)
keep_scores = np.empty((0, ), dtype=np.float32)
keep_labels = np.empty((0, ), dtype=np.float32)

else:
keep_boxes = np.concatenate(keep_boxes, axis=0)
keep_scores = np.concatenate(keep_scores, axis=0)
keep_labels = np.concatenate(keep_labels, axis=0)

return keep_boxes, keep_scores, keep_labels
def nms(bboxes: ndarray, scores: ndarray, iou_thresh: float):
"""
Performs non-maximum suppression (NMS) on the boxes according
to their intersection-over-union (IoU).

NMS iteratively removes lower scoring boxes which have an
IoU greater than iou_threshold with another (higher scoring)
box.

If multiple boxes have the exact same score and satisfy the IoU
criterion with respect to a reference box, the selected box is
not guaranteed to be the same between CPU and GPU. This is similar
to the behavior of argsort in PyTorch when repeated values are present.

Args:
bboxes (ndarray[N, 4])): boxes to perform NMS on. They
are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and
``0 <= y1 < y2``.
scores (ndarray[N]): scores for each one of the boxes
iou_thresh (float): discards all overlapping boxes with IoU > iou_threshold

Returns:
ndarray: int64 tensor with the indices of the elements that have been kept
by NMS, sorted in decreasing order of scores
"""
x1 = bboxes[:, 0]
y1 = bboxes[:, 1]
x2 = bboxes[:, 2]
y2 = bboxes[:, 3]
areas = (y2 - y1) * (x2 - x1)

result = []
index = scores.argsort()[::-1]
while index.size > 0:
i = index[0]
result.append(i)

x11 = np.maximum(x1[i], x1[index[1:]])
y11 = np.maximum(y1[i], y1[index[1:]])
x22 = np.minimum(x2[i], x2[index[1:]])
y22 = np.minimum(y2[i], y2[index[1:]])
w = np.maximum(0, x22 - x11 + 1)
h = np.maximum(0, y22 - y11 + 1)
overlaps = w * h
ious = overlaps / (areas[i] + areas[index[1:]] - overlaps)
idx = np.where(ious <= iou_thresh)[0]
index = index[idx + 1]
return np.array(result, dtype=int)


def path_to_list(images_path: Union[str, Path]) -> List:
Expand Down