From c88ba5f1eb67c0010a0037031cd7e9bd7f3eb0f7 Mon Sep 17 00:00:00 2001 From: franksim Date: Mon, 13 Oct 2025 12:45:39 +0000 Subject: [PATCH] assignment-a1: adds build_batch and get_model function --- a1/main.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/a1/main.py b/a1/main.py index e69de29..02462de 100644 --- a/a1/main.py +++ b/a1/main.py @@ -0,0 +1,35 @@ +import os +from PIL import Image +import torch +from torchvision import transforms, datasets, models +from torchvision.transforms import functional as F + + +def pad_to_square(img): + w, h = img.size + max_wh = max(w, h) + pad = ((max_wh - w) // 2, (max_wh - h) // 2) + padding = (pad[0], pad[1], max_wh - w - pad[0], max_wh - h - pad[1]) + return F.pad(img, padding, fill=0, padding_mode='constant') + + +def build_batch(paths, size=224): + preprocess = transforms.Compose([ + transforms.Lambda(pad_to_square), + transforms.Resize((size, size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + ]) + imgs = [] + for path in paths: + img = Image.open(path).convert('RGB') + img = preprocess(img) + imgs.append(img) + batch = torch.stack(imgs) + return batch + + +def get_model(): + model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT) + return model