diff --git a/mmp/a4/anchor_grid.py b/mmp/a4/anchor_grid.py index 3c54c51..0511d16 100644 --- a/mmp/a4/anchor_grid.py +++ b/mmp/a4/anchor_grid.py @@ -9,4 +9,24 @@ def get_anchor_grid( anchor_widths: Sequence[float], aspect_ratios: Sequence[float], ) -> np.ndarray: - raise NotImplementedError() + anchor_grid = np.empty( + [len(width), len(ratio), num_rows, num_cols, 4], dtype=int) + for width_idx, width in enumerate(anchor_widths): + for ratio_idx, ratio in enumerate(aspect_ratios): + for row in range(num_rows): + for col in range(num_cols): + anchor_point = ( + col * scale_factor + scale_factor/2, row * scale_factor + scale_factor / 2) + anchor_grid[width_idx, ratio_idx, row, col] = get_box( + width, ratio, anchor_point) + + return anchor_grid + + +def get_box(width: float, ratio: float, anchor_point: tuple[float, float]) -> np.ndarray: + box = np.empty(4) + box[0] = anchor_point[0] - (width / 2) + box[1] = anchor_point[1] - (width * ratio / 2) + box[2] = anchor_point[0] + (width / 2) + box[3] = anchor_point[1] + (width * ratio / 2) + return box diff --git a/mmp/a4/label_grid.py b/mmp/a4/label_grid.py index 3b61abd..e29e283 100644 --- a/mmp/a4/label_grid.py +++ b/mmp/a4/label_grid.py @@ -5,10 +5,34 @@ from ..a3.annotation import AnnotationRect def iou(rect1: AnnotationRect, rect2: AnnotationRect) -> float: - raise NotImplementedError() + x_left = max(rect1.x1, rect2.x1) + y_top = max(rect1.y1, rect2.y1) + x_right = min(rect1.x2, rect2.x2) + y_bottom = min(rect1.y2, rect2.y2) + + # Returns 0 if no overlap + if x_right <= x_left or y_bottom <= y_top: + return 0.0 + + intersection_area = (x_right - x_left) * (y_bottom - y_top) + + rect1_area = rect1.area() + rect2_area = rect2.area() + + union_area = rect1_area + rect2_area - intersection_area + + return intersection_area / union_area def get_label_grid( anchor_grid: np.ndarray, gts: Sequence[AnnotationRect], min_iou: float ) -> tuple[np.ndarray, ...]: - raise NotImplementedError() + label_grid = np.empty(anchor_grid.shape[:-1], dtype=bool) + for (width, ratio, row, col), item in np.ndenumerate(anchor_grid): + for gt in gts: + iou = iou(item, gt) + label_grid[width, ratio, row, col] = False + if (iou >= min_iou): + label_grid[width, ratio, row, col] = True + break + return label_grid