diff --git a/a1/main.py b/a1/main.py
deleted file mode 100644
index 02462de..0000000
--- a/a1/main.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import os
-from PIL import Image
-import torch
-from torchvision import transforms, datasets, models
-from torchvision.transforms import functional as F
-
-
-def pad_to_square(img):
-    w, h = img.size
-    max_wh = max(w, h)
-    pad = ((max_wh - w) // 2, (max_wh - h) // 2)
-    padding = (pad[0], pad[1], max_wh - w - pad[0], max_wh - h - pad[1])
-    return F.pad(img, padding, fill=0, padding_mode='constant')
-
-
-def build_batch(paths, size=224):
-    preprocess = transforms.Compose([
-        transforms.Lambda(pad_to_square),
-        transforms.Resize((size, size)),
-        transforms.ToTensor(),
-        transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                             std=[0.229, 0.224, 0.225]),
-    ])
-    imgs = []
-    for path in paths:
-        img = Image.open(path).convert('RGB')
-        img = preprocess(img)
-        imgs.append(img)
-    batch = torch.stack(imgs)
-    return batch
-
-
-def get_model():
-    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
-    return model
diff --git a/a1/requirements.txt b/a1/requirements.txt
deleted file mode 100644
index ea3956e..0000000
--- a/a1/requirements.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-torch
-torchvision
---index-url https://download.pytorch.org/whl/cu129
\ No newline at end of file
diff --git a/mmp/a1/main.py b/mmp/a1/main.py
index c1d0eda..fd807bf 100644
--- a/mmp/a1/main.py
+++ b/mmp/a1/main.py
@@ -1,5 +1,16 @@
 from typing import Sequence
 import torch
+from torchvision.transforms import functional as F
+from torchvision import transforms, models
+from PIL import Image
+
+
+def pad_to_square(img):
+    w, h = img.size
+    max_wh = max(w, h)
+    pad = ((max_wh - w) // 2, (max_wh - h) // 2)
+    padding = (pad[0], pad[1], max_wh - w - pad[0], max_wh - h - pad[1])
+    return F.pad(img, padding, fill=0, padding_mode='constant')
 
 
 def build_batch(paths: Sequence[str], transform=None) -> torch.Tensor:
@@ -9,7 +20,20 @@ def build_batch(paths: Sequence[str], transform=None) -> torch.Tensor:
     @param transform: One or multiple image transformations for augmenting the batch images.
     @return: Returns one single tensor that contains every image.
     """
-    raise NotImplementedError()
+    preprocess = transforms.Compose([
+        transforms.Lambda(pad_to_square),
+        transforms.Resize((224, 224)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225]),
+    ])
+    imgs = []
+    for path in paths:
+        img = Image.open(path).convert('RGB')
+        img = preprocess(img)
+        imgs.append(img)
+    batch = torch.stack(imgs)
+    return batch
 
 
 def get_model() -> torch.nn.Module:
@@ -17,7 +41,9 @@ def get_model() -> torch.nn.Module:
     """
     @return: Returns a neural network, initialised with pretrained weights.
     """
-    raise NotImplementedError()
+    model = models.resnet18(
+        weights=models.ResNet18_Weights.DEFAULT)
+    return model
 
 
 def main():
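
Below is a minimal usage sketch (not part of the diff) showing how the two functions added to mmp/a1/main.py might be wired together for inference. The import path mmp.a1.main and the image file names are assumptions for illustration only.

    # Hypothetical usage of build_batch() and get_model(); the image paths and
    # the import location are placeholders, not taken from the repository.
    import torch
    from mmp.a1.main import build_batch, get_model

    paths = ["example_cat.jpg", "example_dog.jpg"]  # placeholder image files
    batch = build_batch(paths)    # (N, 3, 224, 224) padded, normalised tensor
    model = get_model()           # ResNet-18 with ImageNet-pretrained weights
    model.eval()                  # inference mode (fixed batch-norm stats, no dropout)
    with torch.no_grad():
        logits = model(batch)     # (N, 1000) ImageNet class scores
    preds = logits.argmax(dim=1)  # predicted class index per image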