from typing import List, Tuple import torch import numpy as np from tqdm import tqdm import os from torch.utils.data import DataLoader from mmp.a6.evallib import calculate_ap_pr from ..a4.label_grid import iou from ..a5.model import MmpNet from ..a3.annotation import AnnotationRect, read_groundtruth_file from .nms import non_maximum_suppression def batch_inference( model: MmpNet, images: torch.Tensor, device: torch.device, anchor_grid: np.ndarray ) -> List[List[Tuple[AnnotationRect, float]]]: score_thresh = 0.5 nms_thresh = 0.3 model = model.to(device) model.eval() images = images.to(device) anchor_grid = anchor_grid # shape [W, R, h, w, 4] results = [] with torch.no_grad(): outputs = model(images) # (B, W, R, h, w, 2) probs = torch.softmax(outputs, dim=-1)[..., 1] # (B, W, R, h, w) probs_np = probs.cpu().numpy() batch_size = outputs.shape[0] for b in range(batch_size): detections = [] for idx in np.ndindex(anchor_grid.shape[:-1]): score = probs_np[b][idx] # if score >= score_thresh: box = anchor_grid[idx] rect = AnnotationRect.fromarray(box) detections.append((rect, float(score))) detections_nms = non_maximum_suppression(detections, nms_thresh) results.append(detections_nms) return results def evaluate( model: MmpNet, loader: DataLoader, device: torch.device, anchor_grid: np.ndarray ) -> float: """Evaluates a specified model on the whole validation dataset. @return: AP for the validation set as a float. You decide which arguments this function should receive """ path_to_data = ".data/mmp-public-3.2/train" progress_bar = tqdm(loader, desc="Evaluation", unit="batch") image_count = 0 ap_total = 0 for img_batch, _, id_batch in progress_bar: inference = batch_inference( anchor_grid=anchor_grid, device=device, images=img_batch, model=model ) gts = get_gts_for_batch(id_batch=id_batch, gt_base_path=path_to_data) dict_detections = { img_id.item(): inference[idx] for idx, img_id in enumerate(id_batch) } dict_gt = {img_id.item(): gts[idx] for idx, img_id in enumerate(id_batch)} average_prevision, precision, recall = calculate_ap_pr(dict_detections, dict_gt) ap_total = (ap_total * image_count + average_prevision) / ( image_count + id_batch.shape[0] ) image_count += id_batch.shape[0] progress_bar.set_postfix( { "ap": ap_total, } ) return ap_total def get_gts_for_batch( id_batch: torch.Tensor, gt_base_path: str ) -> List[List[AnnotationRect]]: return [ read_groundtruth_file( os.path.join(gt_base_path, f"{str(img_id.item()).zfill(8)}.gt_data.txt") ) for img_id in id_batch ] def calc_tp_fp_fn( detections: List[Tuple[AnnotationRect, float]], gts: List[AnnotationRect], iou_threshold: float = 0.5, confidence_threshhold: float = 0.5, ) -> tuple[int, int, int]: """ Calculates precision and recall for object detection results on a single image. Args: detections: List of (AnnotationRect, confidence) tuples representing predicted boxes and scores. Should be sorted by descending confidence. gts: List of AnnotationRect for ground truth. iou_threshold: Minimum IoU to consider a detection a true positive. confidence_threshhold: Minimum confidence required to include a detection. Returns: num_tp: Number of true positives (int). num_fp: Number of false positives (int). num_fn: Number of false negatives (int). """ detections = [det for det in detections if det[1] >= confidence_threshhold] detections.sort(key=lambda x: x[1], reverse=True) matches = set() fp = 0 tp = 0 for det_rect, _ in detections: iou_map = [iou(det_rect, gt_rect) for gt_rect in gts] if len(iou_map) == 0: fp += 1 continue max_idx = np.argmax(iou_map) if max_idx in matches or iou_map[max_idx] < iou_threshold: fp += 1 continue matches.add(max_idx) tp += 1 fn = len(gts) - len(matches) return tp, fp, fn def evaluate_test(): # feel free to change the arguments """Generates predictions on the provided test dataset. This function saves the predictions to a text file. You decide which arguments this function should receive """ raise NotImplementedError() def main(): """Put the surrounding training code here. The code will probably look very similar to last assignment""" raise NotImplementedError() if __name__ == "__main__": main()