import os
import re
from PIL import Image
from typing import Tuple
import torch
from torch.utils.data import DataLoader
from .annotation import read_groundtruth_file
from torchvision import transforms


class MMP_Dataset(torch.utils.data.Dataset):
    """Exercise 3.2

    Dataset of ``<number>.jpg`` images, each paired with the annotations read
    from the sibling ``<number>.gt_data.txt`` file. Items are ordered by the
    numeric file id.
    """

    # Per-channel normalization statistics applied after ToTensor().
    NORM_MEAN = [0.485, 0.456, 0.406]
    NORM_STD = [0.229, 0.224, 0.225]

    def __init__(self, path_to_data: str, image_size: int):
        """
        @param path_to_data: Path to the folder that contains the images and annotation files, e.g. dataset_mmp/train
        @param image_size: Desired image size that this dataset should return
        """
        self.image_size = image_size
        img_pattern = re.compile(r"^(\d+)\.jpg$")

        indexed = []
        # Iterate in sorted name order so ties on the numeric id (e.g. "01.jpg"
        # vs "1.jpg") end up in a deterministic order after the stable sort below.
        for fname in sorted(os.listdir(path_to_data)):
            match = img_pattern.match(fname)
            if not match:
                continue
            img_id = match.group(1)
            img_file = os.path.join(path_to_data, fname)
            annotations = read_groundtruth_file(
                os.path.join(path_to_data, f"{img_id}.gt_data.txt")
            )
            indexed.append((int(img_id), img_file, annotations))

        # Sort on the id captured from the bare file name. Re-parsing the joined
        # path with a "/"-based regex breaks on Windows separators and when
        # path_to_data is empty (re.match returns None -> AttributeError).
        indexed.sort(key=lambda entry: entry[0])

        # List of (image path, annotations) pairs, ordered by numeric file id.
        self.images = [(img_file, annotations) for _, img_file, annotations in indexed]

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """
        @return: Tuple of image tensor and label. The label is 1 if the image has
            more than one annotation (multiple people) and 0 otherwise.
        """
        img_file, annotations = self.images[idx]

        # Close the underlying file immediately; convert() returns an
        # independent in-memory copy, so the image stays usable afterwards.
        with Image.open(img_file) as raw:
            img = raw.convert("RGB")

        transform = transforms.Compose(
            [
                # Pad right/bottom with black to a square so Resize keeps the
                # aspect ratio of the original content.
                transforms.Pad(self.__padding__(img), 0),
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=self.NORM_MEAN, std=self.NORM_STD),
            ]
        )
        label = 1 if len(annotations) > 1 else 0
        return transform(img), label

    def __padding__(self, img) -> Tuple[int, int, int, int]:
        """Return (left, top, right, bottom) padding that makes ``img`` square.

        Only the right and bottom edges are padded, so pixel coordinates of the
        original content are unchanged.
        """
        w, h = img.size
        size = max(w, h)
        return (0, 0, size - w, size - h)

    def __len__(self) -> int:
        return len(self.images)


def get_dataloader(
    path_to_data: str,
    image_size: int,
    batch_size: int,
    num_workers: int,
    is_train: bool = True,
) -> DataLoader:
    """Exercise 3.2d

    Build a DataLoader over ``<path_to_data>/train`` (shuffled) when
    ``is_train`` is true, otherwise over ``<path_to_data>/val`` (not shuffled).
    """
    path = os.path.join(path_to_data, "train" if is_train else "val")
    dataset = MMP_Dataset(path_to_data=path, image_size=image_size)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=num_workers,
        pin_memory=True,
    )