Files
mmp_wise2526_franksim/mmp/a3/dataset.py

88 lines
2.7 KiB
Python
Raw Permalink Normal View History

2025-10-28 16:03:53 +00:00
import os
import re
from PIL import Image
2025-10-13 14:48:00 +02:00
from typing import Tuple
import torch
from torch.utils.data import DataLoader
2025-10-31 14:02:28 +00:00
from .annotation import read_groundtruth_file
2025-10-28 16:03:53 +00:00
from torchvision import transforms
2025-10-13 14:48:00 +02:00
class MMP_Dataset(torch.utils.data.Dataset):
    """Exercise 3.2

    Image-classification dataset over a folder of ``<id>.jpg`` images, each
    paired with an ``<id>.gt_data.txt`` annotation file. The label encodes
    whether the image has more than one annotation.
    """

    # Image files are named by a purely numeric id, e.g. "123.jpg" or "007.jpg".
    _IMAGE_NAME = re.compile(r"^(\d+)\.jpg$")

    # Per-channel normalization statistics (the common ImageNet values).
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, path_to_data: str, image_size: int):
        """
        @param path_to_data: Path to the folder that contains the images and annotation files, e.g. dataset_mmp/train
        @param image_size: Desired image size that this dataset should return
        """
        self.image_size = image_size
        # List of (image path, annotations) ordered by numeric image id.
        self.images = [
            (
                os.path.join(path_to_data, fname),
                read_groundtruth_file(
                    os.path.join(path_to_data, f"{img_id}.gt_data.txt")
                ),
            )
            for img_id, fname in self._sorted_image_names(path_to_data)
        ]
        # Resize/ToTensor/Normalize do not depend on the individual image, so
        # this part of the pipeline is built once instead of on every access.
        self._transform = transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=self._MEAN, std=self._STD),
            ]
        )

    @classmethod
    def _sorted_image_names(cls, path_to_data: str) -> list:
        """Return ``(id, filename)`` pairs for every ``<digits>.jpg`` in *path_to_data*.

        Pairs are ordered by the numeric value of the id (ties broken by file
        name). The id is kept as the original string, leading zeros intact, so
        it can be used to build the matching annotation file name.
        """
        found = []
        for fname in os.listdir(path_to_data):
            match = cls._IMAGE_NAME.match(fname)
            if match:
                found.append((match.group(1), fname))
        # Sort on the parsed file name rather than the joined path, so the
        # ordering does not depend on the OS path separator.
        found.sort(key=lambda pair: (int(pair[0]), pair[1]))
        return found

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """
        @return: Tuple of image tensor and label. The label is 0 if there is at most one person and 1 if there are multiple people.
        """
        img_path, annotations = self.images[idx]
        # Context manager releases the underlying file handle deterministically.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        # Pad right/bottom to a square first so the subsequent square resize
        # preserves the aspect ratio of the image content.
        img = transforms.Pad(self._padding(img), fill=0)(img)
        label = 1 if len(annotations) > 1 else 0
        return self._transform(img), label

    def _padding(self, img) -> Tuple[int, int, int, int]:
        """Return ``(left, top, right, bottom)`` padding that makes *img* square.

        Only the right and bottom edges are padded, so the original content
        stays anchored at the top-left corner. ``img.size`` is read as
        ``(width, height)``.
        """
        width, height = img.size
        side = max(width, height)
        return (0, 0, side - width, side - height)

    # Backward-compatible alias for the previous dunder-style method name.
    __padding__ = _padding

    def __len__(self) -> int:
        return len(self.images)
2025-10-13 14:48:00 +02:00
def get_dataloader(
    path_to_data: str,
    image_size: int,
    batch_size: int,
    num_workers: int,
    is_train: bool = True,
) -> DataLoader:
    """Exercise 3.2d

    Build a DataLoader over either the ``train`` or the ``val`` split folder
    located directly below *path_to_data*.

    @param path_to_data: Dataset root that contains the ``train`` and ``val`` folders
    @param image_size: Image size passed through to MMP_Dataset
    @param batch_size: Number of samples per batch
    @param num_workers: Number of loader worker processes
    @param is_train: If True use the ``train`` split with shuffling, otherwise the ``val`` split without shuffling
    @return: DataLoader over the selected split
    """
    split_dir = os.path.join(path_to_data, "train" if is_train else "val")
    return DataLoader(
        MMP_Dataset(path_to_data=split_dir, image_size=image_size),
        batch_size=batch_size,
        # Only the training split is shuffled; validation order stays fixed.
        shuffle=is_train,
        num_workers=num_workers,
        pin_memory=True,
    )