"""Label an anchor grid against ground-truth boxes and visualise the matches."""

import math
import os
from typing import Optional, Sequence

import numpy as np
from PIL import Image, ImageDraw

from ..a3.annotation import AnnotationRect, read_groundtruth_file
from .anchor_grid import get_anchor_grid


def iou(rect1: AnnotationRect, rect2: AnnotationRect) -> float:
    """Return the intersection-over-union of two axis-aligned rectangles.

    Args:
        rect1: First rectangle; must expose ``x1``, ``y1``, ``x2``, ``y2``
            and an ``area()`` method.
        rect2: Second rectangle, same requirements.

    Returns:
        Intersection area divided by union area, or ``0.0`` when the union
        area is not positive (e.g. both rectangles are degenerate).
    """
    x_left = max(rect1.x1, rect2.x1)
    y_top = max(rect1.y1, rect2.y1)
    x_right = min(rect1.x2, rect2.x2)
    y_bottom = min(rect1.y2, rect2.y2)

    # Clamp each extent at zero so disjoint rectangles yield an empty
    # intersection instead of a spurious positive product of two negatives.
    intersection_area = max(x_right - x_left, 0) * max(y_bottom - y_top, 0)
    union_area = rect1.area() + rect2.area() - intersection_area

    # Guard against dividing by zero for zero-area inputs.
    if union_area <= 0:
        return 0.0
    return intersection_area / union_area


def get_label_grid(
    anchor_grid: np.ndarray, gts: Sequence[AnnotationRect], min_iou: float
) -> np.ndarray:
    """Mark every anchor that overlaps at least one ground-truth box.

    Args:
        anchor_grid: Array of anchor boxes. The last axis is passed to
            ``AnnotationRect.fromarray`` (presumably the box coordinates --
            confirm against ``get_anchor_grid``); all leading axes index
            the anchors.
        gts: Ground-truth rectangles. May be empty.
        min_iou: Minimum IoU (inclusive) for an anchor to count as positive.

    Returns:
        Boolean array of shape ``anchor_grid.shape[:-1]``; ``True`` where the
        anchor reaches ``min_iou`` with any ground-truth box.
    """
    # Start from all-False: with no ground truths the loop below never assigns,
    # and an uninitialised buffer would otherwise leak through as labels.
    label_grid = np.zeros(anchor_grid.shape[:-1], dtype=bool)
    if not gts:
        return label_grid

    for idx in np.ndindex(label_grid.shape):
        # Build the anchor rectangle once per cell, not once per ground truth.
        anchor = AnnotationRect.fromarray(anchor_grid[idx])
        # any() short-circuits on the first sufficiently overlapping box.
        label_grid[idx] = any(iou(anchor, gt) >= min_iou for gt in gts)
    return label_grid


def calculate_label_grid_for_image(
    image: str,
    scale_factor: float,
    anchor_widths: Sequence[float],
    aspect_ratios: Sequence[float],
    min_iou: float,
    gt_path: Optional[str] = None,
    output_path: str = "mmp/a4/output.jpg",
) -> None:
    """Compute the label grid for one image and draw the positive anchors.

    Args:
        image: Path to the input image.
        scale_factor: Divisor applied to the image height and width to obtain
            the number of grid rows and columns; also forwarded to
            ``get_anchor_grid``.
        anchor_widths: Anchor widths forwarded to ``get_anchor_grid``.
        aspect_ratios: Anchor aspect ratios forwarded to ``get_anchor_grid``.
        min_iou: Minimum IoU for an anchor to be labelled positive.
        gt_path: Path to the ground-truth file. Defaults to the image path
            with its extension replaced by ``.gt_data.txt``.
        output_path: Where the annotated image is written.
    """
    # Image.open keeps the file open; the context manager closes it once the
    # size has been read.
    with Image.open(image) as im:
        im_width, im_height = im.size

    anchor_grid = get_anchor_grid(
        num_rows=math.floor(im_height / scale_factor),
        num_cols=math.floor(im_width / scale_factor),
        scale_factor=scale_factor,
        anchor_widths=anchor_widths,
        aspect_ratios=aspect_ratios,
    )

    if gt_path is None:
        # Use the annotations that belong to *this* image instead of a fixed
        # file, e.g. ".../02242500.jpg" -> ".../02242500.gt_data.txt".
        gt_path = os.path.splitext(image)[0] + ".gt_data.txt"
    gts = read_groundtruth_file(gt_path)

    label_grid = get_label_grid(anchor_grid=anchor_grid, gts=gts, min_iou=min_iou)

    annotations = [
        AnnotationRect.fromarray(anchor_grid[idx])
        for idx in np.ndindex(anchor_grid.shape[:-1])
        if label_grid[idx]
    ]
    draw_annotation_rects(image, annotations, output_path=output_path)


def draw_annotation_rects(
    image: str,
    annotations: Sequence[AnnotationRect],
    rect_color=(255, 0, 0),
    rect_width=2,
    output_path="output.jpg",
):
    """Draw rectangle outlines onto an image and save the result.

    Args:
        image: Path to the source image.
        annotations: Rectangles to draw; each must expose ``x1``, ``y1``,
            ``x2`` and ``y2``.
        rect_color: Outline colour passed to ``ImageDraw.rectangle``.
        rect_width: Outline width passed to ``ImageDraw.rectangle``.
        output_path: Destination file for the annotated image.
    """
    # convert() returns a new image, so the source file can be closed as soon
    # as the conversion is done.
    with Image.open(image) as source:
        img = source.convert("RGB")

    draw = ImageDraw.Draw(img)
    for annotation in annotations:
        draw.rectangle(
            [annotation.x1, annotation.y1, annotation.x2, annotation.y2],
            outline=rect_color,
            width=rect_width,
        )
    img.save(output_path)


def main():
    """Run the label-grid visualisation on a sample training image."""
    calculate_label_grid_for_image(
        ".data/mmp-public-3.2/train/02242500.jpg",
        8,
        anchor_widths=[16, 32, 64, 96, 128, 144, 150],
        aspect_ratios=[1, 4 / 3, 5 / 3, 2, 2.5, 3],
        min_iou=0.70,
    )


if __name__ == "__main__":
    main()