83 lines
2.7 KiB
Python
83 lines
2.7 KiB
Python
import os
|
|
from typing import List, Sequence, Tuple
|
|
|
|
from ..a3.annotation import AnnotationRect
|
|
from ..a4.label_grid import iou, draw_annotation_rects
|
|
from collections import defaultdict
|
|
|
|
|
|
def non_maximum_suppression(
    boxes_scores: Sequence[Tuple[AnnotationRect, float]], threshold: float
) -> List[Tuple[AnnotationRect, float]]:
    """Exercise 6.1 -- greedy non-maximum suppression.

    Repeatedly keeps the highest-scoring remaining box and discards every
    other remaining box whose IoU with it is greater than ``threshold``.

    @param boxes_scores: Sequence of (annotation, score) tuples
    @param threshold: IoU threshold; a box overlapping an already-kept box by
        more than this is suppressed (an IoU exactly equal to it survives)
    @return: The kept (annotation, score) tuples, in descending score order
    """
    if not boxes_scores:
        return []

    # sorted() is stable, so boxes with equal scores keep their input order.
    candidates = sorted(boxes_scores, key=lambda item: item[1], reverse=True)
    kept: List[Tuple[AnnotationRect, float]] = []

    while candidates:
        best_box, best_score = candidates[0]
        kept.append((best_box, best_score))
        # Everything after the chosen box survives only if it does not
        # overlap the chosen box by more than `threshold`.
        candidates = [
            (other_box, other_score)
            for other_box, other_score in candidates[1:]
            if iou(best_box, other_box) <= threshold
        ]

    return kept
|
|
|
|
|
|
def read_boxes_from_file(filepath: str) -> List[Tuple[str, AnnotationRect, float]]:
    """Read detector boxes and scores from a whitespace-separated text file.

    Each line is expected to have the form::

        {image_id} {x1} {y1} {x2} {y2} {score}

    Lines that do not split into exactly six fields (e.g. blank lines) are
    skipped.

    @param filepath: Path to the model-output text file
    @return: A list of ``(image_id, AnnotationRect(x1, y1, x2, y2), score)``
        tuples in file order; ``image_id`` is kept as the raw string token
    @raise ValueError: If a six-field line has a coordinate that is not an
        integer literal or a score that is not a number
    """
    boxes: List[Tuple[str, AnnotationRect, float]] = []
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            # split() with no argument already discards surrounding whitespace.
            parts = line.split()
            if len(parts) != 6:
                continue
            img_id, *coords, score_text = parts
            # NOTE(review): int() rejects fractional values such as "12.5";
            # confirm the model output only ever contains integer coordinates.
            x1, y1, x2, y2 = (int(c) for c in coords)
            boxes.append((img_id, AnnotationRect(x1, y1, x2, y2), float(score_text)))
    return boxes
|
|
|
|
|
|
def main(
    *,
    model_output_path: str = "mmp/a6/model_output.txt",
    image_dir: str = ".data/mmp-public-3.2/test",
    output_dir: str = "mmp/a6",
    iou_threshold: float = 0.3,
    score_threshold: float = 0.5,
) -> None:
    """Apply NMS to the detector output and draw the surviving boxes per image.

    Boxes read from ``model_output_path`` are grouped by image id and passed
    through ``non_maximum_suppression`` at ``iou_threshold``. The boxes whose
    score is strictly greater than ``score_threshold`` are then drawn
    (rect_color=(255, 0, 0), width 2) onto
    ``{image_dir}/{image_id}.jpg`` and saved as
    ``{output_dir}/nms_output_{image_id}.png``. Images whose source file does
    not exist are skipped, with a message printed.

    Every keyword argument defaults to the value the script previously
    hard-coded, so a bare ``main()`` call processes the same files as before.
    """
    grouped = defaultdict(list)
    for image_id, rect, score in read_boxes_from_file(model_output_path):
        grouped[image_id].append((rect, score))

    for image_id, rects_scores in grouped.items():
        input_path = f"{image_dir}/{image_id}.jpg"
        # Check the image first so NMS is not run for images we cannot draw.
        if not os.path.exists(input_path):
            print(f"Skipping image {image_id}: {input_path} not found")
            continue

        kept = non_maximum_suppression(rects_scores, iou_threshold)
        annotation_rects = [rect for rect, score in kept if score > score_threshold]

        draw_annotation_rects(
            input_path,
            annotation_rects,
            rect_color=(255, 0, 0),
            rect_width=2,
            output_path=f"{output_dir}/nms_output_{image_id}.png",
        )


if __name__ == "__main__":
    main()
|