Files
mmp_wise2526_franksim/mmp/a1/main.py

84 lines
2.3 KiB
Python
Raw Normal View History

2025-10-13 14:48:00 +02:00
from typing import Sequence
import torch
2025-10-14 07:45:19 +00:00
import torchvision
2025-10-13 13:04:11 +00:00
from torchvision.transforms import functional as F
from torchvision import transforms, models
from PIL import Image
def pad_to_square(img):
    """Pad *img* with black borders so that width and height become equal.

    The image content stays centred; when the size difference is odd, the
    extra pixel goes to the right/bottom border.

    @param img: Image exposing ``size`` as ``(width, height)``.
    @return: The result of ``F.pad`` with a constant fill of 0.
    """
    width, height = img.size
    side = max(width, height)

    # Split the missing pixels per axis between the two opposite borders.
    left = (side - width) // 2
    top = (side - height) // 2
    right = side - width - left
    bottom = side - height - top

    # F.pad takes the borders in (left, top, right, bottom) order.
    borders = (left, top, right, bottom)
    return F.pad(img, borders, fill=0, padding_mode='constant')
2025-10-13 14:48:00 +02:00
def build_batch(paths: Sequence[str], transform=None) -> torch.Tensor:
    """Exercise 1.1

    Load the images at *paths* and stack them into a single batch tensor.

    Every image is opened, converted to RGB, optionally augmented with
    *transform*, padded to a square, resized to 224x224, converted to a
    tensor and normalised per channel.

    @param paths: A sequence (e.g. list) of strings, each specifying the location of an image file.
    @param transform: One or multiple image transformations for augmenting the batch images.
        Either a single callable or a list/tuple of callables. It is applied
        to the RGB PIL image before the default preprocessing. ``None``
        (the default) applies no augmentation.
    @return: Returns one single tensor that contains every image, with shape
        ``(len(paths), 3, 224, 224)``. An empty *paths* yields an empty
        batch of shape ``(0, 3, 224, 224)``.
    """
    side = 224

    # torch.stack refuses an empty list, so return an empty batch explicitly.
    if not paths:
        return torch.empty((0, 3, side, side))

    # "One or multiple" transformations: chain a list/tuple into one callable.
    if isinstance(transform, (list, tuple)):
        transform = transforms.Compose(list(transform))

    preprocess = transforms.Compose([
        transforms.Lambda(pad_to_square),
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        # Per-channel statistics commonly used with ImageNet-pretrained models.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    imgs = []
    for path in paths:
        # convert() returns an independent, fully loaded copy, so the file
        # handle can be released right away.
        with Image.open(path) as handle:
            img = handle.convert('RGB')
        if transform is not None:
            # Augment the raw image before padding/resizing/normalising.
            img = transform(img)
        imgs.append(preprocess(img))

    return torch.stack(imgs)
2025-10-13 14:48:00 +02:00
def get_model() -> torch.nn.Module:
    """Exercise 1.2

    Build a ResNet-18 and load the default pretrained weights into it.

    @return: Returns a neural network, initialised with pretrained weights.
    """
    pretrained = models.ResNet18_Weights.DEFAULT
    return models.resnet18(weights=pretrained)
2025-10-13 14:48:00 +02:00
def main():
    """Exercise 1.3

    Classify a fixed set of example images with the pretrained network and
    print, per image, the highest output score and the matching class name.
    """
    paths = [
        "./images/golden retriever.jpg",
        "./images/koala.jpg",
        "./images/pacifier.jpg",
        "./images/rubber duck sculpture.jpg",
        "./images/rubber ducks.jpg",
        "./images/shoehorn.jpg",
        "./images/zoo.jpg",
    ]
    batch = build_batch(paths)
    model = get_model()

    # Switch to inference mode: otherwise BatchNorm layers normalise with the
    # statistics of this small batch (and update their running averages)
    # instead of using the pretrained ones, which distorts the predictions.
    model.eval()

    # No gradients are needed for a pure forward pass.
    with torch.no_grad():
        outputs = model(batch)

    # Highest score per image and the index of the class that produced it.
    max_scores, preds = outputs.max(dim=1)
    class_names = models.ResNet18_Weights.DEFAULT.meta["categories"]

    for path, pred, score in zip(paths, preds, max_scores):
        print(f"Image: {path}")
        print(f" Model output score: {score.item():.4f}")
        print(f" Predicted class: {class_names[pred.item()]}")
        print()
2025-10-13 14:48:00 +02:00
# Run the exercise 1.3 demo only when this file is executed as a script,
# so importing the module has no side effects.
if __name__ == "__main__":
    main()