adds anchor grid

This commit is contained in:
franksim
2025-11-06 15:14:11 +01:00
parent ef540f128e
commit 74da909b30
2 changed files with 47 additions and 3 deletions

View File

@@ -5,10 +5,34 @@ from ..a3.annotation import AnnotationRect
def iou(rect1: AnnotationRect, rect2: AnnotationRect) -> float:
raise NotImplementedError()
x_left = max(rect1.x1, rect2.x1)
y_top = max(rect1.y1, rect2.y1)
x_right = min(rect1.x2, rect2.x2)
y_bottom = min(rect1.y2, rect2.y2)
# Returns 0 if no overlap
if x_right <= x_left or y_bottom <= y_top:
return 0.0
intersection_area = (x_right - x_left) * (y_bottom - y_top)
rect1_area = rect1.area()
rect2_area = rect2.area()
union_area = rect1_area + rect2_area - intersection_area
return intersection_area / union_area
def get_label_grid(
anchor_grid: np.ndarray, gts: Sequence[AnnotationRect], min_iou: float
) -> tuple[np.ndarray, ...]:
raise NotImplementedError()
label_grid = np.empty(anchor_grid.shape[:-1], dtype=bool)
for (width, ratio, row, col), item in np.ndenumerate(anchor_grid):
for gt in gts:
iou = iou(item, gt)
label_grid[width, ratio, row, col] = False
if (iou >= min_iou):
label_grid[width, ratio, row, col] = True
break
return label_grid