diff --git a/mmp/a4/label_grid.py b/mmp/a4/label_grid.py index c16e182..ebc3d12 100644 --- a/mmp/a4/label_grid.py +++ b/mmp/a4/label_grid.py @@ -26,9 +26,9 @@ def get_label_grid( label_grid = np.empty(anchor_grid.shape[:-1], dtype=bool) for (width, ratio, row, col), item in np.ndenumerate(anchor_grid): for gt in gts: - iou = iou(item, gt) + calculated_iou = iou(item, gt) label_grid[width, ratio, row, col] = False - if iou >= min_iou: + if calculated_iou >= min_iou: label_grid[width, ratio, row, col] = True break return label_grid