import os
import re
from PIL import Image
from typing import Tuple
import torch
from torch.utils.data import DataLoader
from a3.annotation import read_groundtruth_file
from torchvision import transforms


class MMP_Dataset(torch.utils.data.Dataset):
    """Exercise 3.2

    Image-classification dataset over a folder of ``<id>.jpg`` images, each
    paired with a ``<id>.gt_data.txt`` annotation file. Samples are ordered by
    the numeric image id.
    """

    def __init__(self, path_to_data: str, image_size: int):
        """
        @param path_to_data: Path to the folder that contains the images and annotation files, e.g. dataset_mmp/train
        @param image_size: Desired image size that this dataset should return
        """
        self.image_size = image_size
        # Only files named "<digits>.jpg" are images; group 1 is the image id.
        img_pattern = re.compile(r'^(\d+)\.jpg$')

        indexed = []
        # sorted() keeps the result deterministic even if two names map to the
        # same integer id (e.g. "01.jpg" and "1.jpg").
        for fname in sorted(os.listdir(path_to_data)):
            match = img_pattern.match(fname)
            if not match:
                continue
            img_id = match.group(1)
            img_file = os.path.join(path_to_data, fname)
            annotations = read_groundtruth_file(
                os.path.join(path_to_data, f"{img_id}.gt_data.txt"))
            indexed.append((int(img_id), img_file, annotations))

        # Sort numerically by the id captured from the bare filename. Matching
        # on the joined path (as before) assumed a "/" separator and crashed on
        # Windows paths or when path_to_data was "".
        indexed.sort(key=lambda entry: entry[0])
        # Each entry: (image file path, annotations returned by read_groundtruth_file).
        self.images = [(img_file, annotations)
                       for _, img_file, annotations in indexed]

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """
        @return: Tuple of image tensor and label. The label is 1 if the image
            has more than one annotation (multiple people) and 0 otherwise
            (one person, or no annotations at all).
        """
        img_file, annotations = self.images[idx]
        # Open in a context manager so the file handle is released; convert()
        # returns a new in-memory image that remains valid after closing.
        with Image.open(img_file) as raw:
            img = raw.convert("RGB")

        padding = self.__padding__(img)
        transform = transforms.Compose([
            # Zero-pad right/bottom to a square so Resize keeps the aspect ratio.
            transforms.Pad(padding, 0),
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            # Standard ImageNet per-channel mean/std.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
        label = 1 if len(annotations) > 1 else 0
        return transform(img), label

    def __padding__(self, img) -> Tuple[int, int, int, int]:
        """Return (left, top, right, bottom) padding that makes *img* square.

        The image is kept at the top-left corner; only the right and bottom
        edges are padded up to max(width, height).
        """
        w, h = img.size
        size = max(w, h)
        return (0, 0, size - w, size - h)

    def __len__(self) -> int:
        """Return the number of images found in the data folder."""
        return len(self.images)


def get_dataloader(
    path_to_data: str,
    image_size: int,
    batch_size: int,
    num_workers: int,
    is_train: bool = True
) -> DataLoader:
    """Exercise 3.2d

    Build a DataLoader over ``<path_to_data>/train`` (shuffled) when
    ``is_train`` is True, otherwise over ``<path_to_data>/val`` (unshuffled).
    """
    path = os.path.join(path_to_data, "train" if is_train else "val")
    dataset = MMP_Dataset(path_to_data=path, image_size=image_size)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=num_workers,
        pin_memory=True
    )
    return dataloader