# mmp/a4/dataset.py
import os
import re
from itertools import islice
from typing import Sequence, Tuple

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

from ..a3.annotation import read_groundtruth_file, AnnotationRect
from .anchor_grid import get_anchor_grid
from .label_grid import get_label_grid


class MMP_Dataset(torch.utils.data.Dataset):
def __init__(
self,
path_to_data: str,
image_size: int,
anchor_grid: np.ndarray,
min_iou: float,
is_test: bool,
):
"""
@param anchor_grid: The anchor grid to be used for every image
@param min_iou: The minimum IoU that is required for an overlap for the label grid.
@param is_test: Whether this is the test set (True) or the validation/training set (False)
"""
self.image_size = image_size
        self.images: list[Tuple[str, str | None]] = []
self.anchor_grid = anchor_grid
self.min_iou = min_iou
self.is_test = is_test
self.path_to_data = path_to_data
img_pattern = re.compile(r"^(\d+)\.jpg$")
files = set(os.listdir(path_to_data))
for fname in files:
match = img_pattern.match(fname)
if match:
img_file = os.path.join(path_to_data, fname)
                if is_test:
                    self.images.append((img_file, None))
                else:
                    annotation_file = os.path.join(
                        path_to_data, f"{match.group(1)}.gt_data.txt"
                    )
                    self.images.append((img_file, annotation_file))
self.images.sort(
key=lambda x: int(re.match(r"(.*/)(\d+)(\.jpg)", x[0]).group(2))
)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
"""
@return: 3-tuple of image tensor, label grid, and image (file-)number
"""
img = Image.open(self.images[idx][0]).convert("RGB")
        padding = self._padding(img)
transform = transforms.Compose(
[
transforms.Pad(padding, 0),
transforms.Resize((self.image_size, self.image_size)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
]
)
img_tensor = transform(img)
img_id = re.match(r".*(\/)([0-9]+)(\.[^\/]*$)", self.images[idx][0]).group(2)
if self.is_test:
return (img_tensor, torch.Tensor(), int(img_id))
annotations = [
annotation.scale(self.image_size / max(img.size[0], img.size[1]))
for annotation in read_groundtruth_file(self.images[idx][1])
]
label_grid = get_label_grid(
anchor_grid=self.anchor_grid, gts=annotations, min_iou=self.min_iou
)
return (img_tensor, label_grid, int(img_id))

    def __len__(self) -> int:
        return len(self.images)

    def _padding(self, img: Image.Image) -> Tuple[int, int, int, int]:
        """Pad on the right/bottom so the image becomes square before resizing."""
w, h = img.size
size = max(w, h)
right_pad = size - w
bottom_pad = size - h
return (0, 0, right_pad, bottom_pad)
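
# Usage sketch (illustrative, not part of the original module): building the
# dataset directly, assuming the .data layout used in main() below, where each
# NNNNNNNN.jpg has a matching NNNNNNNN.gt_data.txt:
#
#   grid = get_anchor_grid(anchor_widths=[32], aspect_ratios=[1.0],
#                          num_rows=28, num_cols=28, scale_factor=8)
#   ds = MMP_Dataset(".data/mmp-public-3.2/train", image_size=224,
#                    anchor_grid=grid, min_iou=0.7, is_test=False)
#   img_tensor, label_grid, img_id = ds[0]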


def get_dataloader(
path_to_data: str,
image_size: int,
batch_size: int,
num_workers: int,
anchor_grid: np.ndarray,
is_test: bool,
is_train: bool = False,
) -> DataLoader:
dataset = MMP_Dataset(
path_to_data=path_to_data,
image_size=image_size,
is_test=is_test,
anchor_grid=anchor_grid,
min_iou=0.7,
)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=is_train,
num_workers=num_workers,
pin_memory=True,
)
return dataloader


def calculate_coverage(loader: DataLoader, min_iou: float) -> float:
"""
@param loader: DataLoader object.
@param min_iou: Minimum IoU overlap to count a ground truth box as covered.
    @return: Fraction of ground truth boxes covered by at least one anchor box, between 0 and 1.
"""
total_boxes = 0
covered_boxes = 0
dataset = loader.dataset
    anchor_grid = dataset.anchor_grid  # arbitrary leading dims; last dim is [x1, y1, x2, y2]
# Reshape anchor grid to (N, 4)
anchors = anchor_grid.reshape(-1, 4)
for img, _, img_id in loader:
for batch_index in range(len(img)):
gts_file = os.path.join(
dataset.path_to_data,
f"{str(img_id[batch_index].item()).zfill(8)}.gt_data.txt",
)
            # Read the original image size to reproduce the scale factor that
            # the dataset applies to its ground truth boxes
            with Image.open(
                os.path.join(
                    dataset.path_to_data,
                    f"{str(img_id[batch_index].item()).zfill(8)}.jpg",
                )
            ) as original_image:
                original_w, original_h = original_image.size
            # The model input is square; infer its side length from the tensor
            transformed_size = img[batch_index].shape[-1]
            scale = transformed_size / max(original_w, original_h)
            annotations = [
                annotation.scale(scale)
                for annotation in read_groundtruth_file(gts_file)
            ]
            if not annotations:
                continue  # no ground truth boxes for this image
            gt_boxes = np.stack(
                [np.array(a) for a in annotations], axis=0
            )  # shape (M, 4)
            total_boxes += len(gt_boxes)
# Vectorized IoU calculation: (M, N)
ious = compute_ious_vectorized(gt_boxes, anchors) # shape (M, N)
# Count ground truths for which any anchor box matches min_iou
covered = (ious >= min_iou).any(axis=1).sum()
covered_boxes += covered
return covered_boxes / total_boxes if total_boxes > 0 else 0.0


def compute_ious_vectorized(boxes1: np.ndarray, boxes2: np.ndarray) -> np.ndarray:
"""
Compute the IoU matrix between each box in boxes1 and each box in boxes2.
boxes1: (M, 4), boxes2: (N, 4) -- format [x1, y1, x2, y2]
Returns: (M, N) IoU
"""
M, N = boxes1.shape[0], boxes2.shape[0]
# Expand to (M, N, 4)
boxes1 = boxes1[:, None, :] # (M, 1, 4)
boxes2 = boxes2[None, :, :] # (1, N, 4)
# Intersection box
inter_x1 = np.maximum(boxes1[..., 0], boxes2[..., 0])
inter_y1 = np.maximum(boxes1[..., 1], boxes2[..., 1])
inter_x2 = np.minimum(boxes1[..., 2], boxes2[..., 2])
inter_y2 = np.minimum(boxes1[..., 3], boxes2[..., 3])
inter_w = np.clip(inter_x2 - inter_x1, 0, None)
inter_h = np.clip(inter_y2 - inter_y1, 0, None)
inter_area = inter_w * inter_h
area1 = (boxes1[..., 2] - boxes1[..., 0]) * (boxes1[..., 3] - boxes1[..., 1])
area2 = (boxes2[..., 2] - boxes2[..., 0]) * (boxes2[..., 3] - boxes2[..., 1])
union_area = area1 + area2 - inter_area
return inter_area / (union_area + 1e-6)
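
# Quick sanity check (illustrative): identical boxes give an IoU close to 1,
# disjoint boxes give 0.
#   a = np.array([[0.0, 0.0, 10.0, 10.0]])
#   b = np.array([[0.0, 0.0, 10.0, 10.0], [20.0, 20.0, 30.0, 30.0]])
#   compute_ious_vectorized(a, b)  # -> approx [[1.0, 0.0]]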


def draw_image_tensor_with_annotations(
img: torch.Tensor,
    annotations: Sequence[AnnotationRect] | None,
output_file: str,
):
# Convert tensor to numpy, permute dimensions
img_np = img.permute(1, 2, 0).numpy()
img_np = np.clip(img_np, 0, 1)
fig, ax = plt.subplots(1)
ax.imshow(img_np)
    for rect in annotations or []:
x1, y1, x2, y2 = rect.x1, rect.y1, rect.x2, rect.y2
w = x2 - x1
h = y2 - y1
patch = patches.Rectangle(
(x1, y1), w, h, linewidth=2, edgecolor="red", facecolor="none"
)
ax.add_patch(patch)
plt.axis("off")
plt.tight_layout(pad=0)
plt.savefig(output_file, bbox_inches="tight", pad_inches=0)
plt.close(fig)


def denormalize_image_tensor(
img: torch.Tensor,
mean=torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1),
std=torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1),
) -> torch.Tensor:
img_denormalized = img * std + mean
return img_denormalized
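
# Round-trip sketch (assumes the same mean/std as MMP_Dataset.__getitem__):
#   x = torch.rand(3, 8, 8)
#   norm = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
#   torch.allclose(denormalize_image_tensor(norm(x)), x, atol=1e-6)  # ~True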


def draw_positive_boxes(
img_tensor: torch.Tensor,
    label_grid: torch.Tensor,
    img_id: int,
anchor_grid: np.ndarray,
):
annotations = [
AnnotationRect.fromarray(anchor_grid[idx])
for idx in np.ndindex(anchor_grid.shape[:-1])
if label_grid[idx]
]
    # Make sure the output directory exists before saving
    os.makedirs("mmp/a4/.output", exist_ok=True)
    draw_image_tensor_with_annotations(
img_tensor,
annotations=annotations,
output_file=f"mmp/a4/.output/{img_id}_transformed.png",
)


def main():
anchor_grid = get_anchor_grid(
anchor_widths=[8, 16, 32, 64, 96, 128, 160, 192],
aspect_ratios=[1 / 3, 1 / 2, 3 / 5, 2 / 3, 3 / 4, 1, 4 / 3, 5 / 3, 2, 2.5, 3],
num_rows=28,
num_cols=28,
scale_factor=8,
)
dataloader = get_dataloader(
num_workers=9,
is_train=True,
is_test=False,
batch_size=8,
image_size=224,
path_to_data=".data/mmp-public-3.2/train",
anchor_grid=anchor_grid,
)
# print(calculate_coverage(dataloader, 0.7))
for img, label, img_id in islice(dataloader, 12):
        # Visualize the sixth sample of each batch (assumes batch_size >= 6)
        draw_positive_boxes(
img_tensor=denormalize_image_tensor(img=img[5]),
            label_grid=label[5],
            img_id=int(img_id[5]),
anchor_grid=anchor_grid,
)


if __name__ == "__main__":
main()