From 89faa2649bc5c4c2b7df4d500b25e59ca605d971 Mon Sep 17 00:00:00 2001
From: franksim
Date: Tue, 14 Oct 2025 08:53:19 +0000
Subject: [PATCH] assignment-a1: add different batches for different
 transforms

---
 mmp/a1/main.py | 55 ++++++++++++++++++++++++++++++++------------------
 1 file changed, 35 insertions(+), 20 deletions(-)

diff --git a/mmp/a1/main.py b/mmp/a1/main.py
index e134645..c9eb169 100644
--- a/mmp/a1/main.py
+++ b/mmp/a1/main.py
@@ -2,7 +2,7 @@ from typing import Sequence
 import torch
 import torchvision
 from torchvision.transforms import functional as F
-from torchvision import transforms, models
+from torchvision import models, transforms
 from PIL import Image
 
 
@@ -14,7 +14,7 @@ def pad_to_square(img):
     return F.pad(img, padding, fill=0, padding_mode='constant')
 
 
-def build_batch(paths: Sequence[str], transform=None) -> torch.Tensor:
+def build_batch(paths: Sequence[str], size=(224, 224), additional_transforms=()) -> torch.Tensor:
     """Exercise 1.1
 
     @param paths: A sequence (e.g. list) of strings, each specifying the location of an image file.
@@ -23,12 +23,15 @@ def build_batch(paths: Sequence[str], transform=None) -> torch.Tensor:
     """
     preprocess = transforms.Compose([
         transforms.Lambda(pad_to_square),
-        transforms.Resize((224, 224)),
-        transforms.ToTensor(),
-        transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                             std=[0.229, 0.224, 0.225]),
-    ])
+        transforms.Resize(size)]
+        + list(additional_transforms)
+        + [transforms.ToTensor(),
+           transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                std=[0.229, 0.224, 0.225]),
+           ]
+    )
     imgs = []
+
     for path in paths:
         img = Image.open(path).convert('RGB')
         img = preprocess(img)
@@ -47,6 +50,21 @@ def get_model() -> torch.nn.Module:
     return model
 
 
+def forward_pass(paths, batch, model):
+    with torch.no_grad():
+        outputs = model(batch)
+
+    max_scores, preds = outputs.max(dim=1)
+
+    class_names = torchvision.models.ResNet18_Weights.DEFAULT.meta["categories"]
+
+    for i, (p, s) in enumerate(zip(preds, max_scores)):
+        print(f"Image: {paths[i]}")
+        print(f"  Model output score: {s.item():.4f}")
+        print(f"  Predicted class: {class_names[p.item()]}")
+        print()
+
+
 def main():
     """Exercise 1.3
 
@@ -62,22 +80,19 @@ def main():
         "./images/shoehorn.jpg",
         "./images/zoo.jpg",
     ]
-    batch = build_batch(paths)
+    batch_a = build_batch(paths)
     model = get_model()
+    print("Batch A:")
+    forward_pass(paths, batch_a, model)
 
-    with torch.no_grad():
-        outputs = model(batch)
-
-    max_scores, preds = outputs.max(dim=1)
-
-    class_names = torchvision.models.ResNet18_Weights.DEFAULT.meta["categories"]
-
-    for i, (p, s) in enumerate(zip(preds, max_scores)):
-        print(f"Image: {paths[i]}")
-        print(f"  Model output score: {s.item():.4f}")
-        print(f"  Predicted class: {class_names[p.item()]}")
-        print()
+    print("Batch B:")
+    batch_b = build_batch(paths, (400, 400))
+    forward_pass(paths, batch_b, model)
 
+    print("Batch C:")
+    batch_c = build_batch(paths, additional_transforms=[
+        transforms.RandomVerticalFlip(p=1.0)])
+    forward_pass(paths, batch_c, model)
 
 if __name__ == "__main__":
     main()
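
Note (usage sketch, not part of the patch): additional_transforms is spliced
into the pipeline after Resize and before ToTensor, so every extra transform
must accept a PIL image rather than a tensor. Assuming mmp/a1 is importable
as a package (hypothetical import path; the package layout is not shown in
this patch), Batch C could equivalently be built like this:

    from torchvision import transforms
    from mmp.a1.main import build_batch  # assumed import path

    paths = ["./images/shoehorn.jpg", "./images/zoo.jpg"]

    # PIL-level augmentation; p=1.0 makes the "random" flip deterministic.
    batch = build_batch(paths, size=(224, 224),
                        additional_transforms=[transforms.RandomVerticalFlip(p=1.0)])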
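
Note (reviewer sketch, not part of the patch): the new forward_pass() prints
the raw logit returned by outputs.max(dim=1), which is hard to compare across
images. A softmax over the class dimension turns the logits into
probabilities. A minimal variant, assuming the same model, batch, and
class_names as above (forward_pass_probs is a hypothetical name):

    import torch

    def forward_pass_probs(paths, batch, model, class_names):
        # Same no-grad forward pass as in the patch.
        with torch.no_grad():
            logits = model(batch)
        probs = torch.softmax(logits, dim=1)  # normalize over the class axis
        top_probs, preds = probs.max(dim=1)
        for path, pred, prob in zip(paths, preds, top_probs):
            print(f"Image: {path}")
            print(f"  Confidence: {prob.item():.2%}")
            print(f"  Predicted class: {class_names[pred.item()]}")
            print()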