formatting
This commit is contained in:
@@ -17,7 +17,7 @@ class MMP_Dataset(torch.utils.data.Dataset):
|
||||
@param image_size: Desired image size that this dataset should return
|
||||
"""
|
||||
self.image_size = image_size
|
||||
img_pattern = re.compile(r'^(\d+)\.jpg$')
|
||||
img_pattern = re.compile(r"^(\d+)\.jpg$")
|
||||
files = set(os.listdir(path_to_data))
|
||||
self.images = []
|
||||
|
||||
@@ -25,12 +25,14 @@ class MMP_Dataset(torch.utils.data.Dataset):
|
||||
match = img_pattern.match(fname)
|
||||
if match:
|
||||
img_file = os.path.join(path_to_data, fname)
|
||||
annotations = read_groundtruth_file(os.path.join(
|
||||
path_to_data, f"{match.group(1)}.gt_data.txt"))
|
||||
annotations = read_groundtruth_file(
|
||||
os.path.join(path_to_data, f"{match.group(1)}.gt_data.txt")
|
||||
)
|
||||
self.images.append((img_file, annotations))
|
||||
|
||||
self.images.sort(key=lambda x: int(
|
||||
re.match(r"(.*/)(\d+)(\.jpg)", x[0]).group(2)))
|
||||
self.images.sort(
|
||||
key=lambda x: int(re.match(r"(.*/)(\d+)(\.jpg)", x[0]).group(2))
|
||||
)
|
||||
|
||||
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
|
||||
"""
|
||||
@@ -38,15 +40,16 @@ class MMP_Dataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
img = Image.open(self.images[idx][0]).convert("RGB")
|
||||
padding = self.__padding__(img)
|
||||
transform = transforms.Compose([
|
||||
transforms.Pad(padding, 0),
|
||||
transforms.Resize((self.image_size, self.image_size)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225]
|
||||
)
|
||||
])
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Pad(padding, 0),
|
||||
transforms.Resize((self.image_size, self.image_size)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
),
|
||||
]
|
||||
)
|
||||
return (transform(img), 1 if len(self.images[idx][1]) > 1 else 0)
|
||||
|
||||
def __padding__(self, img) -> Tuple[int, int, int, int]:
|
||||
@@ -61,16 +64,24 @@ class MMP_Dataset(torch.utils.data.Dataset):
|
||||
|
||||
|
||||
def get_dataloader(
|
||||
path_to_data: str, image_size: int, batch_size: int, num_workers: int, is_train: bool = True
|
||||
path_to_data: str,
|
||||
image_size: int,
|
||||
batch_size: int,
|
||||
num_workers: int,
|
||||
is_train: bool = True,
|
||||
) -> DataLoader:
|
||||
"""Exercise 3.2d"""
|
||||
path = os.path.join(path_to_data, "train") if is_train else os.path.join(
|
||||
path_to_data, "val")
|
||||
path = (
|
||||
os.path.join(path_to_data, "train")
|
||||
if is_train
|
||||
else os.path.join(path_to_data, "val")
|
||||
)
|
||||
dataset = MMP_Dataset(path_to_data=path, image_size=image_size)
|
||||
dataloader = DataLoader(
|
||||
dataset, batch_size=batch_size,
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=is_train,
|
||||
num_workers=num_workers,
|
||||
pin_memory=True
|
||||
pin_memory=True,
|
||||
)
|
||||
return dataloader
|
||||
|
||||
Reference in New Issue
Block a user