import os
from collections import defaultdict
from typing import List, Sequence, Tuple

from ..a3.annotation import AnnotationRect
from ..a4.label_grid import iou, draw_annotation_rects


def non_maximum_suppression(
    boxes_scores: Sequence[Tuple[AnnotationRect, float]], threshold: float
) -> List[Tuple[AnnotationRect, float]]:
    """Exercise 6.1: greedy non-maximum suppression.

    Repeatedly keeps the highest-scoring remaining box and discards every
    other remaining box whose IoU with it is strictly greater than
    ``threshold``. The input sequence itself is not modified.

    @param boxes_scores: Sequence of tuples of annotations and scores
    @param threshold: Threshold for NMS; boxes with IoU <= threshold survive
    @return: A list of tuples of the remaining boxes after NMS together with
             their scores, ordered by descending score
    """
    if not boxes_scores:
        return []

    # sorted() builds a new list, so the caller's sequence stays untouched.
    # The sort is stable: equally scored boxes keep their input order.
    remaining = sorted(boxes_scores, key=lambda bs: bs[1], reverse=True)

    kept = []
    while remaining:
        # The head of the list is the best-scoring box still in play.
        best_box, best_score = remaining[0]
        kept.append((best_box, best_score))

        # Keep only the boxes that do not overlap the chosen one too much.
        remaining = [
            (box, score)
            for box, score in remaining[1:]
            if iou(best_box, box) <= threshold
        ]

    return kept


def read_boxes_from_file(filepath: str) -> List[Tuple[str, AnnotationRect, float]]:
    """Read bounding boxes and scores from a whitespace-separated text file.

    Each usable line has the format:
    {image_number} {x1} {y1} {x2} {y2} {score}

    Lines that do not consist of exactly six fields (including blank lines)
    are skipped silently.

    @param filepath: Path of the text file to read
    @return: A list of tuples (image_number, AnnotationRect(x1, y1, x2, y2),
             score) in file order; image_number is kept as a string
    @raise ValueError: If a coordinate is not an integer literal or the score
                       cannot be parsed as a float
    """
    boxes: List[Tuple[str, AnnotationRect, float]] = []

    with open(filepath, "r") as f:
        for line in f:
            # split() without arguments already ignores surrounding whitespace.
            parts = line.split()
            if len(parts) != 6:
                continue

            img_id = parts[0]
            x1, y1, x2, y2 = map(int, parts[1:5])
            score = float(parts[5])
            boxes.append((img_id, AnnotationRect(x1, y1, x2, y2), score))

    return boxes


def main():
    """Apply NMS to the model output and draw the surviving boxes per image.

    Reads detections from ``mmp/a6/model_output.txt``, groups them by image
    id, runs NMS with an IoU threshold of 0.3 on each group, keeps the boxes
    scoring above 0.5 and hands them to ``draw_annotation_rects`` together
    with the input image path and an output path. Images whose file does not
    exist are skipped.
    """
    boxes = read_boxes_from_file("mmp/a6/model_output.txt")

    # NMS is applied per image, so collect the detections by image id first.
    grouped = defaultdict(list)
    for image_id, rect, score in boxes:
        grouped[image_id].append((rect, score))

    for image_id, rects_scores in grouped.items():
        input_path = f".data/mmp-public-3.2/test/{image_id}.jpg"
        # Check for the image before running NMS so that no work is spent on
        # detections that cannot be drawn anyway.
        if not os.path.exists(input_path):
            continue
        output_path = f"mmp/a6/nms_output_{image_id}.png"

        filtered_boxes = non_maximum_suppression(rects_scores, 0.3)
        # Only boxes with a score above 0.5 are drawn.
        annotation_rects = [rect for rect, score in filtered_boxes if score > 0.5]

        draw_annotation_rects(
            input_path,
            annotation_rects,
            rect_color=(255, 0, 0),
            rect_width=2,
            output_path=output_path,
        )


if __name__ == "__main__":
    main()