formatting

This commit is contained in:
franksim
2025-11-07 11:20:08 +01:00
parent 8fc3559d6c
commit b159d76517
8 changed files with 119 additions and 82 deletions

View File

@@ -10,19 +10,25 @@ def get_anchor_grid(
aspect_ratios: Sequence[float],
) -> np.ndarray:
anchor_grid = np.empty(
[len(anchor_widths), len(aspect_ratios), num_rows, num_cols, 4], dtype=float)
for (width_idx, ratio_idx, row, col) in np.ndindex(anchor_grid.shape[:-1]):
[len(anchor_widths), len(aspect_ratios), num_rows, num_cols, 4], dtype=float
)
for width_idx, ratio_idx, row, col in np.ndindex(anchor_grid.shape[:-1]):
anchor_point = (
col * scale_factor + scale_factor / 2, row * scale_factor + scale_factor / 2)
col * scale_factor + scale_factor / 2,
row * scale_factor + scale_factor / 2,
)
width = anchor_widths[width_idx]
ratio = aspect_ratios[ratio_idx]
anchor_grid[width_idx, ratio_idx, row, col] = get_box(
width, ratio, anchor_point)
width, ratio, anchor_point
)
return anchor_grid
def get_box(width: float, ratio: float, anchor_point: tuple[float, float]) -> np.ndarray:
def get_box(
width: float, ratio: float, anchor_point: tuple[float, float]
) -> np.ndarray:
box = np.empty(4, dtype=float)
box[0] = anchor_point[0] - (width / 2)
box[1] = anchor_point[1] - (width * ratio / 2)

View File

@@ -28,7 +28,7 @@ def get_label_grid(
for gt in gts:
iou = iou(item, gt)
label_grid[width, ratio, row, col] = False
if (iou >= min_iou):
if iou >= min_iou:
label_grid[width, ratio, row, col] = True
break
return label_grid