update template code

This commit is contained in:
phatakmr
2025-10-13 14:48:00 +02:00
parent c9d159fcc6
commit 8f637a4a0d
46 changed files with 2955 additions and 1 deletions

12
mmp/a4/anchor_grid.py Normal file
View File

@@ -0,0 +1,12 @@
from typing import Sequence
import numpy as np
def get_anchor_grid(
num_rows: int,
num_cols: int,
scale_factor: float,
anchor_widths: Sequence[float],
aspect_ratios: Sequence[float],
) -> np.ndarray:
raise NotImplementedError()

50
mmp/a4/dataset.py Normal file
View File

@@ -0,0 +1,50 @@
from typing import Tuple
import numpy as np
import torch
from torch.utils.data import DataLoader
class MMP_Dataset(torch.utils.data.Dataset):
def __init__(
self,
path_to_data: str,
image_size: int,
anchor_grid: np.ndarray,
min_iou: float,
is_test: bool,
):
"""
@param anchor_grid: The anchor grid to be used for every image
@param min_iou: The minimum IoU that is required for an overlap for the label grid.
@param is_test: Whether this is the test set (True) or the validation/training set (False)
"""
raise NotImplementedError()
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
"""
@return: 3-tuple of image tensor, label grid, and image (file-)number
"""
raise NotImplementedError()
def __len__(self) -> int:
raise NotImplementedError()
def get_dataloader(
path_to_data: str,
image_size: int,
batch_size: int,
num_workers: int,
anchor_grid: np.ndarray,
is_test: bool,
) -> DataLoader:
raise NotImplementedError()
def calculate_max_coverage(loader: DataLoader, min_iou: float) -> float:
"""
@param loader: A DataLoader object, generated with the get_dataloader function.
@param min_iou: Minimum IoU overlap that is required to count a ground truth box as covered.
@return: Ratio of how mamy ground truth boxes are covered by a label grid box. Must be a value between 0 and 1.
"""
raise NotImplementedError()

14
mmp/a4/label_grid.py Normal file
View File

@@ -0,0 +1,14 @@
from typing import Sequence
import numpy as np
from ..a3.annotation import AnnotationRect
def iou(rect1: AnnotationRect, rect2: AnnotationRect) -> float:
raise NotImplementedError()
def get_label_grid(
anchor_grid: np.ndarray, gts: Sequence[AnnotationRect], min_iou: float
) -> tuple[np.ndarray, ...]:
raise NotImplementedError()