mmp_wise2526_franksim/mmp/a6/evallib.py

import numpy as np
from ..a4.label_grid import iou

def _assign_detections(det_boxes_scores, gt_boxes, min_iou):
    """Returns a list of 3-tuples (det_box, det_score, closest_gt_box)."""
    assert min_iou > 0, "min_iou must be greater than zero"
    det2gt = []
    for det_box, det_score in det_boxes_scores:
        # find the ground-truth box with the highest IoU for this detection
        best_iou = -1.0
        closest_gt_box = None
        for gt in gt_boxes:
            curr_iou = iou(gt, det_box)
            if curr_iou > best_iou:
                best_iou = curr_iou
                closest_gt_box = gt
        # keep the match only if it clears the IoU threshold
        if best_iou >= min_iou:
            det2gt.append((det_box, det_score, closest_gt_box))
        else:
            det2gt.append((det_box, det_score, None))
    return det2gt
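
# A minimal sketch of what _assign_detections returns, using plain
# (x1, y1, x2, y2) tuples instead of AnnotationRect objects. The tuple
# format is an assumption for illustration only; the real iou() from a4
# may expect AnnotationRect instances rather than tuples.
#
#   dets = [((0, 0, 10, 10), 0.9), ((50, 50, 60, 60), 0.3)]
#   gts = [(1, 1, 10, 10)]
#   _assign_detections(dets, gts, min_iou=0.5)
#   -> [((0, 0, 10, 10), 0.9, (1, 1, 10, 10)),   # matched, IoU ~ 0.81
#       ((50, 50, 60, 60), 0.3, None)]           # no overlapping GT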

def calculate_ap_pr(det_boxes_scores: dict, gt_boxes: dict, min_iou=0.5):
    """
    Calculates the average precision (AP) for the given detections and ground truths.
    This function also returns the precision and recall values that can be used
    to plot a PR curve.
    @param det_boxes_scores: A dictionary that maps image numbers to a list of tuples,
        each containing a detected AnnotationRect object and a confidence score.
    @param gt_boxes: A dictionary that maps image numbers to the list of ground truth
        AnnotationRect objects.
    @param min_iou: Minimum IoU at which a detection counts as matching a ground truth.
    @return: A 3-tuple (average_precision, precision, recall).
    """
    for dets in det_boxes_scores.values():
        assert all(
            isinstance(d, tuple) and len(d) == 2 for d in dets
        ), "Your detection boxes must have scores"
    assert (
        len(set(gt_boxes.keys()).intersection(det_boxes_scores.keys())) > 0
    ), "The two dictionaries have no common keys. Maybe you have selected the wrong dataset?"
    gts_total = 0
    dets2gts_flat = []
    for img in gt_boxes:
        gts_total += len(gt_boxes[img])
        det2gts = _assign_detections(
            det_boxes_scores.get(img, []), gt_boxes[img], min_iou
        )
        dets2gts_flat.extend(det2gts)
    # sort by detection confidence, descending (p[1] is the score)
    dets2gts_flat.sort(key=lambda p: p[1], reverse=True)
    tp = np.zeros(len(dets2gts_flat), dtype=np.float32)
    fp = np.zeros(len(dets2gts_flat), dtype=np.float32)
    fn_cum = np.zeros(len(dets2gts_flat), dtype=np.float32)
    # seed with None so unmatched detections (gt is None) never count as "seen"
    gts_seen = {None}
    for idx, (_det, _det_score, gt) in enumerate(dets2gts_flat):
        gt_prev_seen = gt in gts_seen
        # detection assigned to a GT that was not previously matched -> TP
        tp[idx] = gt is not None and not gt_prev_seen
        # detection unassigned, or assigned to an already-matched GT -> FP
        fp[idx] = (gt is None) or gt_prev_seen
        gts_seen.add(gt)
        # all GTs that have not been matched yet -> FN
        # (len(gts_seen) - 1 because the set also contains the None seed)
        fn_cum[idx] = gts_total - (len(gts_seen) - 1)
    # dets2gts_flat is sorted by score, so tp and fp can be accumulated with cumsum
    tp_cum = np.cumsum(tp)
    fp_cum = np.cumsum(fp)
    precision = tp_cum / (tp_cum + fp_cum)
    recall = tp_cum / (tp_cum + fn_cum)
    # AP as the area under the PR curve (np.trapz was renamed np.trapezoid in NumPy 2.0)
    average_precision = np.trapz(precision, recall)
    return average_precision, precision, recall
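
# Worked illustration of the cumulative counting above (toy values, assumed
# purely for illustration): with gts_total = 2 and the sorted detection
# outcomes [TP, FP, TP],
#   tp_cum = [1, 1, 2], fp_cum = [0, 1, 1], fn_cum = [1, 1, 0]
#   precision = [1, 1/2, 2/3], recall = [1/2, 1/2, 1]
# so precision dips when the false positive arrives and recall rises only
# when a new ground truth is matched.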

# Usage example (running it requires importing AnnotationRect, e.g. from a4):
# def example():
#     dboxes = {
#         "31": [
#             (AnnotationRect(1, 2, 3, 4), 0.234),
#             (AnnotationRect(9, 8, 7, 6), 2.431093),
#         ],
#         "32": [
#             (AnnotationRect(1, 2, 3, 4), 0.234),
#             (AnnotationRect(9, 8, 7, 6), 2.431093),
#         ],
#         # ...
#     }
#     gboxes = {
#         "31": [AnnotationRect(0, 3, 2, 3)],
#         "32": [AnnotationRect(9, 9, 12, 23), AnnotationRect(1, 2, 40, 40)],
#     }
#     ap, _, _ = calculate_ap_pr(dboxes, gboxes)
#     print("ap is", ap)
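
# One possible way to plot the PR curve from the returned arrays; the
# matplotlib usage here is a sketch and not part of the assignment code:
#
#   import matplotlib.pyplot as plt
#   ap, precision, recall = calculate_ap_pr(dboxes, gboxes)
#   plt.plot(recall, precision)
#   plt.xlabel("recall")
#   plt.ylabel("precision")
#   plt.title(f"PR curve (AP = {ap:.3f})")
#   plt.show()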