diff --git a/a1/main.py b/a1/main.py index e69de29..02462de 100644 --- a/a1/main.py +++ b/a1/main.py @@ -0,0 +1,35 @@ +import os +from PIL import Image +import torch +from torchvision import transforms, datasets, models +from torchvision.transforms import functional as F + + +def pad_to_square(img): + w, h = img.size + max_wh = max(w, h) + pad = ((max_wh - w) // 2, (max_wh - h) // 2) + padding = (pad[0], pad[1], max_wh - w - pad[0], max_wh - h - pad[1]) + return F.pad(img, padding, fill=0, padding_mode='constant') + + +def build_batch(paths, size=224): + preprocess = transforms.Compose([ + transforms.Lambda(pad_to_square), + transforms.Resize((size, size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + ]) + imgs = [] + for path in paths: + img = Image.open(path).convert('RGB') + img = preprocess(img) + imgs.append(img) + batch = torch.stack(imgs) + return batch + + +def get_model(): + model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT) + return model