impoves code

This commit is contained in:
franksim
2025-11-07 11:15:51 +01:00
parent 74da909b30
commit 8fc3559d6c
2 changed files with 10 additions and 15 deletions

View File

@@ -10,21 +10,20 @@ def get_anchor_grid(
aspect_ratios: Sequence[float], aspect_ratios: Sequence[float],
) -> np.ndarray: ) -> np.ndarray:
anchor_grid = np.empty( anchor_grid = np.empty(
[len(width), len(ratio), num_rows, num_cols, 4], dtype=int) [len(anchor_widths), len(aspect_ratios), num_rows, num_cols, 4], dtype=float)
for width_idx, width in enumerate(anchor_widths): for (width_idx, ratio_idx, row, col) in np.ndindex(anchor_grid.shape[:-1]):
for ratio_idx, ratio in enumerate(aspect_ratios): anchor_point = (
for row in range(num_rows): col * scale_factor + scale_factor / 2, row * scale_factor + scale_factor / 2)
for col in range(num_cols): width = anchor_widths[width_idx]
anchor_point = ( ratio = aspect_ratios[ratio_idx]
col * scale_factor + scale_factor/2, row * scale_factor + scale_factor / 2) anchor_grid[width_idx, ratio_idx, row, col] = get_box(
anchor_grid[width_idx, ratio_idx, row, col] = get_box( width, ratio, anchor_point)
width, ratio, anchor_point)
return anchor_grid return anchor_grid
def get_box(width: float, ratio: float, anchor_point: tuple[float, float]) -> np.ndarray: def get_box(width: float, ratio: float, anchor_point: tuple[float, float]) -> np.ndarray:
box = np.empty(4) box = np.empty(4, dtype=float)
box[0] = anchor_point[0] - (width / 2) box[0] = anchor_point[0] - (width / 2)
box[1] = anchor_point[1] - (width * ratio / 2) box[1] = anchor_point[1] - (width * ratio / 2)
box[2] = anchor_point[0] + (width / 2) box[2] = anchor_point[0] + (width / 2)

View File

@@ -10,11 +10,7 @@ def iou(rect1: AnnotationRect, rect2: AnnotationRect) -> float:
x_right = min(rect1.x2, rect2.x2) x_right = min(rect1.x2, rect2.x2)
y_bottom = min(rect1.y2, rect2.y2) y_bottom = min(rect1.y2, rect2.y2)
# Returns 0 if no overlap intersection_area = max(x_right - x_left, 0) * max(y_bottom - y_top, 0)
if x_right <= x_left or y_bottom <= y_top:
return 0.0
intersection_area = (x_right - x_left) * (y_bottom - y_top)
rect1_area = rect1.area() rect1_area = rect1.area()
rect2_area = rect2.area() rect2_area = rect2.area()