adds nms and eval

This commit is contained in:
franksim
2025-12-02 11:04:47 +01:00
parent 3b6a588719
commit a6f70005f2
9 changed files with 428 additions and 985 deletions

0
mmp/a6/__init__.py Normal file
View File

View File

@@ -1,25 +1,144 @@
from typing import List, Tuple
import torch
import numpy as np
from tqdm import tqdm
import os
from torch.utils.data import DataLoader
from mmp.a6.evallib import calculate_ap_pr
from ..a4.label_grid import iou
from ..a5.model import MmpNet
from ..a3.annotation import AnnotationRect
from ..a3.annotation import AnnotationRect, read_groundtruth_file
from .nms import non_maximum_suppression
def batch_inference(
model: MmpNet, images: torch.Tensor, device: torch.device, anchor_grid: np.ndarray
) -> List[List[Tuple[AnnotationRect, float]]]:
raise NotImplementedError()
score_thresh = 0.5
nms_thresh = 0.3
model = model.to(device)
model.eval()
images = images.to(device)
anchor_grid = anchor_grid # shape [W, R, h, w, 4]
results = []
with torch.no_grad():
outputs = model(images) # (B, W, R, h, w, 2)
probs = torch.softmax(outputs, dim=-1)[..., 1] # (B, W, R, h, w)
probs_np = probs.cpu().numpy()
batch_size = outputs.shape[0]
for b in range(batch_size):
detections = []
for idx in np.ndindex(anchor_grid.shape[:-1]):
score = probs_np[b][idx]
# if score >= score_thresh:
box = anchor_grid[idx]
rect = AnnotationRect.fromarray(box)
detections.append((rect, float(score)))
detections_nms = non_maximum_suppression(detections, nms_thresh)
results.append(detections_nms)
return results
def evaluate() -> float: # feel free to change the arguments
def evaluate(
model: MmpNet, loader: DataLoader, device: torch.device, anchor_grid: np.ndarray
) -> float:
"""Evaluates a specified model on the whole validation dataset.
@return: AP for the validation set as a float.
You decide which arguments this function should receive
"""
raise NotImplementedError()
path_to_data = ".data/mmp-public-3.2/train"
progress_bar = tqdm(loader, desc="Evaluation", unit="batch")
image_count = 0
ap_total = 0
for img_batch, _, id_batch in progress_bar:
inference = batch_inference(
anchor_grid=anchor_grid, device=device, images=img_batch, model=model
)
gts = get_gts_for_batch(id_batch=id_batch, gt_base_path=path_to_data)
dict_detections = {
img_id.item(): inference[idx] for idx, img_id in enumerate(id_batch)
}
dict_gt = {img_id.item(): gts[idx] for idx, img_id in enumerate(id_batch)}
average_prevision, precision, recall = calculate_ap_pr(dict_detections, dict_gt)
ap_total = (ap_total * image_count + average_prevision) / (
image_count + id_batch.shape[0]
)
image_count += id_batch.shape[0]
progress_bar.set_postfix(
{
"ap": ap_total,
}
)
return ap_total
def get_gts_for_batch(
id_batch: torch.Tensor, gt_base_path: str
) -> List[List[AnnotationRect]]:
return [
read_groundtruth_file(
os.path.join(gt_base_path, f"{str(img_id.item()).zfill(8)}.gt_data.txt")
)
for img_id in id_batch
]
def calc_tp_fp_fn(
detections: List[Tuple[AnnotationRect, float]],
gts: List[AnnotationRect],
iou_threshold: float = 0.5,
confidence_threshhold: float = 0.5,
) -> tuple[int, int, int]:
"""
Calculates precision and recall for object detection results on a single image.
Args:
detections: List of (AnnotationRect, confidence) tuples representing predicted boxes and scores. Should be sorted by descending confidence.
gts: List of AnnotationRect for ground truth.
iou_threshold: Minimum IoU to consider a detection a true positive.
confidence_threshhold: Minimum confidence required to include a detection.
Returns:
num_tp: Number of true positives (int).
num_fp: Number of false positives (int).
num_fn: Number of false negatives (int).
"""
detections = [det for det in detections if det[1] >= confidence_threshhold]
detections.sort(key=lambda x: x[1], reverse=True)
matches = set()
fp = 0
tp = 0
for det_rect, _ in detections:
iou_map = [iou(det_rect, gt_rect) for gt_rect in gts]
if len(iou_map) == 0:
fp += 1
continue
max_idx = np.argmax(iou_map)
if max_idx in matches or iou_map[max_idx] < iou_threshold:
fp += 1
continue
matches.add(max_idx)
tp += 1
fn = len(gts) - len(matches)
return tp, fp, fn
def evaluate_test(): # feel free to change the arguments

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,9 @@
import os
from typing import List, Sequence, Tuple
from ..a3.annotation import AnnotationRect
from ..a4.label_grid import iou, draw_annotation_rects
from collections import defaultdict
def non_maximum_suppression(
@@ -12,4 +15,68 @@ def non_maximum_suppression(
@return: A list of tuples of the remaining boxes after NMS together with their scores
"""
raise NotImplementedError()
if not boxes_scores:
return []
# Sort the boxes by score in descending order
boxes_scores_sorted = sorted(boxes_scores, key=lambda bs: bs[1], reverse=True)
result = []
while boxes_scores_sorted:
# Select the box with highest score and remove it from the list
curr_box, curr_score = boxes_scores_sorted.pop(0)
result.append((curr_box, curr_score))
# Remove boxes with IoU > threshold
new_boxes = []
for box, score in boxes_scores_sorted:
if iou(curr_box, box) <= threshold:
new_boxes.append((box, score))
boxes_scores_sorted = new_boxes
return result
def read_boxes_from_file(filepath: str) -> List[Tuple[str, AnnotationRect, float]]:
"""
Reads a file containing bounding boxes and scores in the format:
{image_number} {x1} {y1} {x2} {y2} {score}
Returns a list of tuples: (image_number, x1, y1, x2, y2, score)
"""
boxes: List[Tuple[AnnotationRect, float]] = []
with open(filepath, "r") as f:
for line in f:
parts = line.strip().split()
if len(parts) != 6:
continue
img_id = parts[0]
x1, y1, x2, y2 = map(int, parts[1:5])
annotation_rect = AnnotationRect(x1, y1, x2, y2)
score = float(parts[5])
boxes.append((img_id, annotation_rect, score))
return boxes
def main():
boxes = read_boxes_from_file("mmp/a6/model_output.txt")
grouped = defaultdict(list)
for image_id, rect, score in boxes:
grouped[image_id].append((rect, score))
for image_id, rects_scores in grouped.items():
filtered_boxes = non_maximum_suppression(rects_scores, 0.3)
annotation_rects = [rect for rect, score in filtered_boxes if score > 0.5]
input_path = f".data/mmp-public-3.2/test/{image_id}.jpg"
output_path = f"mmp/a6/nms_output_{image_id}.png"
if not os.path.exists(input_path):
continue
draw_annotation_rects(
input_path,
annotation_rects,
rect_color=(255, 0, 0),
rect_width=2,
output_path=output_path,
)
if __name__ == "__main__":
main()

Binary file not shown.

After

Width:  |  Height:  |  Size: 359 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 231 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 232 KiB