import os

import torch
from PIL import Image
from torchvision import datasets, models, transforms
from torchvision.transforms import functional as F

# Per-channel RGB mean/std used for normalization in build_batch
# (the standard ImageNet statistics).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def pad_to_square(img):
    """Pad an image with black borders so that it becomes square.

    The shorter side is padded symmetrically. When the size difference is
    odd, the extra pixel goes on the right / bottom edge.

    Args:
        img: A PIL image (``img.size`` is read as ``(width, height)``).

    Returns:
        The padded image, of size ``max(w, h)`` x ``max(w, h)``.
    """
    w, h = img.size
    side = max(w, h)
    left = (side - w) // 2
    top = (side - h) // 2
    # torchvision's F.pad interprets a 4-tuple as (left, top, right, bottom).
    padding = (left, top, side - w - left, side - h - top)
    return F.pad(img, padding, fill=0, padding_mode='constant')


def build_batch(paths, size=224):
    """Load images from disk and stack them into a normalized tensor batch.

    Each image is converted to RGB, padded to a square with
    ``pad_to_square``, resized to ``size`` x ``size``, converted to a float
    tensor and normalized with ``_IMAGENET_MEAN`` / ``_IMAGENET_STD``.

    Args:
        paths: Iterable of image file paths.
        size: Output height and width in pixels.

    Returns:
        A tensor of shape ``(num_images, 3, size, size)``.

    Raises:
        RuntimeError: raised by ``torch.stack`` when ``paths`` is empty.
    """
    preprocess = transforms.Compose([
        transforms.Lambda(pad_to_square),
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])
    tensors = []
    for path in paths:
        # Close the underlying file promptly: Image.open keeps the handle
        # open, while .convert() returns an independent, fully loaded copy.
        with Image.open(path) as src:
            rgb = src.convert('RGB')
        tensors.append(preprocess(rgb))
    return torch.stack(tensors)


def get_model():
    """Return a ResNet-18 built with torchvision's default pretrained weights.

    The module is returned exactly as constructed, i.e. in training mode.
    Call ``.eval()`` on it before running inference.
    """
    return models.resnet18(weights=models.ResNet18_Weights.DEFAULT)