import os
from typing import List, Tuple

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from mmp.a6.evallib import calculate_ap_pr

from ..a3.annotation import AnnotationRect, read_groundtruth_file
from ..a4.label_grid import iou
from ..a5.model import MmpNet
from .nms import non_maximum_suppression


def batch_inference(
    model: MmpNet, images: torch.Tensor, device: torch.device, anchor_grid: np.ndarray
) -> List[List[Tuple[AnnotationRect, float]]]:
    """Runs the model on one image batch and returns, for each image, the list
    of (AnnotationRect, score) detections that survive non-maximum suppression."""
    nms_thresh = 0.3

    model = model.to(device)
    model.eval()
    images = images.to(device)
    # anchor_grid has shape (W, R, h, w, 4); its leading axes match the
    # anchor axes of the model output below.

    results = []
    with torch.no_grad():
        outputs = model(images)  # (B, W, R, h, w, 2)
        # Softmax over the two class logits; keep the foreground probability.
        probs = torch.softmax(outputs, dim=-1)[..., 1]  # (B, W, R, h, w)
        probs_np = probs.cpu().numpy()

        batch_size = outputs.shape[0]
        for b in range(batch_size):
            detections = []
            # No score threshold is applied here, so the AP computation later
            # sees the full score range; NMS still prunes overlapping boxes.
            for idx in np.ndindex(anchor_grid.shape[:-1]):
                score = probs_np[b][idx]
                box = anchor_grid[idx]
                rect = AnnotationRect.fromarray(box)
                detections.append((rect, float(score)))
            detections_nms = non_maximum_suppression(detections, nms_thresh)
            results.append(detections_nms)

    return results
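
# Usage sketch (the tensor shapes are assumptions taken from the shape
# comments above; the loader tuple layout matches evaluate() below):
#
#     images, _, ids = next(iter(loader))   # images: (B, C, H, W)
#     dets = batch_inference(model, images, device, anchor_grid)
#     dets[0]  # [(AnnotationRect, score), ...] for the first image, after NMS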


def evaluate(
    model: MmpNet, loader: DataLoader, device: torch.device, anchor_grid: np.ndarray
) -> float:
    """Evaluates a specified model on the whole validation dataset.

    @return: AP for the validation set as a float.
    """
    path_to_data = ".data/mmp-public-3.2/train"

    progress_bar = tqdm(loader, desc="Evaluation", unit="batch")
    image_count = 0
    ap_total = 0.0
    for img_batch, _, id_batch in progress_bar:
        inference = batch_inference(
            anchor_grid=anchor_grid, device=device, images=img_batch, model=model
        )
        gts = get_gts_for_batch(id_batch=id_batch, gt_base_path=path_to_data)

        # Key detections and ground truth by image id so that calculate_ap_pr
        # can match them up.
        dict_detections = {
            img_id.item(): inference[idx] for idx, img_id in enumerate(id_batch)
        }
        dict_gt = {img_id.item(): gts[idx] for idx, img_id in enumerate(id_batch)}
        # calculate_ap_pr also returns the precision/recall curves, which are
        # not needed here.
        average_precision, _, _ = calculate_ap_pr(dict_detections, dict_gt)
        # Running mean of the per-batch APs, weighted by batch size; this
        # approximates the dataset-level AP.
        ap_total = (ap_total * image_count + average_precision * id_batch.shape[0]) / (
            image_count + id_batch.shape[0]
        )
        image_count += id_batch.shape[0]

        progress_bar.set_postfix({"ap": ap_total})

    return ap_total


def get_gts_for_batch(
    id_batch: torch.Tensor, gt_base_path: str
) -> List[List[AnnotationRect]]:
    """Reads the ground-truth boxes for every image id in the batch."""
    return [
        read_groundtruth_file(
            os.path.join(gt_base_path, f"{str(img_id.item()).zfill(8)}.gt_data.txt")
        )
        for img_id in id_batch
    ]
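
# For example, img_id == 7 with the gt_base_path used in evaluate() resolves to
# ".data/mmp-public-3.2/train/00000007.gt_data.txt" (ids are zero-padded to
# eight digits).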


def calc_tp_fp_fn(
    detections: List[Tuple[AnnotationRect, float]],
    gts: List[AnnotationRect],
    iou_threshold: float = 0.5,
    confidence_threshold: float = 0.5,
) -> Tuple[int, int, int]:
    """Calculates true positive, false positive, and false negative counts for
    object detection results on a single image.

    Args:
        detections: List of (AnnotationRect, confidence) tuples representing
            predicted boxes and scores; they are sorted by descending
            confidence internally.
        gts: List of AnnotationRect for the ground truth.
        iou_threshold: Minimum IoU to consider a detection a true positive.
        confidence_threshold: Minimum confidence required to include a detection.

    Returns:
        num_tp: Number of true positives (int).
        num_fp: Number of false positives (int).
        num_fn: Number of false negatives (int).
    """
    # Keep only confident detections and match them greedily against the
    # ground truth in order of descending confidence.
    detections = [det for det in detections if det[1] >= confidence_threshold]
    detections.sort(key=lambda x: x[1], reverse=True)

    matches = set()
    fp = 0
    tp = 0

    for det_rect, _ in detections:
        iou_map = [iou(det_rect, gt_rect) for gt_rect in gts]
        if len(iou_map) == 0:
            # No ground truth on this image: every detection is a false positive.
            fp += 1
            continue
        max_idx = int(np.argmax(iou_map))
        if max_idx in matches or iou_map[max_idx] < iou_threshold:
            # The best-overlapping ground-truth box is already matched, or the
            # overlap is too small: count the detection as a false positive.
            fp += 1
            continue
        matches.add(max_idx)
        tp += 1

    # Every unmatched ground-truth box is a false negative.
    fn = len(gts) - len(matches)

    return tp, fp, fn
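
# A small convenience sketch (not part of the assignment interface): derives
# precision and recall from the counts returned by calc_tp_fp_fn, guarding
# against empty denominators.
def precision_recall(tp: int, fp: int, fn: int) -> Tuple[float, float]:
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    return precision, recall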


def evaluate_test():  # feel free to change the arguments
    """Generates predictions on the provided test dataset.

    This function saves the predictions to a text file.

    You decide which arguments this function should receive
    """
    raise NotImplementedError()
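
# A minimal sketch of what evaluate_test could look like, with one possible
# argument choice. Assumptions, not the assignment's required format: the test
# loader yields (images, _, ids) like the validation loader, each output line
# is "<id> x1 y1 x2 y2 score", and AnnotationRect exposes x1/y1/x2/y2.
def _evaluate_test_sketch(
    model: MmpNet,
    loader: DataLoader,
    device: torch.device,
    anchor_grid: np.ndarray,
    out_path: str = "predictions.txt",
) -> None:
    with open(out_path, "w") as f:
        for img_batch, _, id_batch in tqdm(loader, desc="Test", unit="batch"):
            inference = batch_inference(
                model=model, images=img_batch, device=device, anchor_grid=anchor_grid
            )
            for idx, img_id in enumerate(id_batch):
                for rect, score in inference[idx]:
                    f.write(
                        f"{img_id.item()} {rect.x1} {rect.y1} "
                        f"{rect.x2} {rect.y2} {score}\n"
                    )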


def main():
    """Put the surrounding training code here. The code will probably look very similar to the last assignment."""
    raise NotImplementedError()
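
# A rough sketch of the wiring main() could contain. Everything not defined in
# this module (the MmpNet constructor arguments, the checkpoint path, the
# loader and anchor-grid construction) is a hypothetical placeholder from the
# previous assignment's setup:
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = MmpNet(...)                # built as in the last assignment
#     model.load_state_dict(torch.load("checkpoint.pt", map_location=device))
#     val_loader = ...                   # yields (images, labels, ids) batches
#     anchor_grid = ...                  # same anchor grid as during training
#     ap = evaluate(model, val_loader, device, anchor_grid)
#     print(f"Validation AP: {ap:.4f}")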


if __name__ == "__main__":
    main()