2025-10-13 14:48:00 +02:00
|
|
|
from typing import Sequence
|
|
|
|
|
import torch
|
2025-10-14 07:45:19 +00:00
|
|
|
import torchvision
|
2025-10-13 13:04:11 +00:00
|
|
|
from torchvision.transforms import functional as F
|
2025-10-14 08:53:19 +00:00
|
|
|
from torchvision import models, transforms
|
2025-10-13 13:04:11 +00:00
|
|
|
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pad_to_square(img):
    """Pad an image with zero-filled borders until width and height are equal.

    @param img: An image whose ``size`` attribute unpacks to ``(width, height)``.

    @return: The result of ``F.pad`` with a constant fill of 0. The longer side
        is left untouched; on the shorter side an odd leftover pixel goes to
        the right/bottom border.
    """
    width, height = img.size
    side = max(width, height)

    # Floor division gives the leading border; the trailing border takes
    # whatever is still missing so both axes end up exactly `side` long.
    left = (side - width) // 2
    top = (side - height) // 2
    right = side - width - left
    bottom = side - height - top

    return F.pad(img, (left, top, right, bottom), fill=0, padding_mode="constant")
|
2025-10-13 14:48:00 +02:00
|
|
|
|
|
|
|
|
|
2025-10-16 14:39:25 +00:00
|
|
|
def build_batch(paths: Sequence[str], transform=None) -> torch.Tensor:
    """Exercise 1.1

    Load every image as RGB, pad it to a square, resize it to 224x224, apply
    the optional augmentation(s) and stack the results into one tensor.

    @param paths: A sequence (e.g. list) of strings, each specifying the location of an image file.

    @param transform: One or multiple image transformations for augmenting the batch images.
        Either a single callable or a list/tuple of callables. They are applied
        in the given order, after the resize and before the conversion to a
        tensor.

    @return: Returns one single tensor that contains every image.
    """
    # Normalise the argument to a list so that both a single transform and a
    # list/tuple of transforms can be spliced into the pipeline.
    if transform is None:
        augmentations = []
    elif isinstance(transform, (list, tuple)):
        augmentations = list(transform)
    else:
        augmentations = [transform]

    preprocess = transforms.Compose(
        [
            transforms.Lambda(pad_to_square),
            transforms.Resize((224, 224)),
            *augmentations,
            transforms.ToTensor(),
        ]
    )

    imgs = []
    for path in paths:
        # Image.open keeps the file handle around; the context manager closes
        # it deterministically. convert() returns a separate image object, so
        # it stays usable after the handle is gone.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        imgs.append(preprocess(img))

    # torch.stack requires at least one tensor and identical shapes, so an
    # empty `paths` still raises here, as it did before.
    return torch.stack(imgs)
|
2025-10-13 14:48:00 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_model() -> torch.nn.Module:
    """Exercise 1.2

    @return: Returns a ResNet18 network, initialised with the default pretrained weights.
    """
    pretrained = models.ResNet18_Weights.DEFAULT
    return models.resnet18(weights=pretrained)
|
2025-10-13 14:48:00 +02:00
|
|
|
|
|
|
|
|
|
2025-10-14 08:53:19 +00:00
|
|
|
def forward_pass(paths, batch, model):
    """Run the model on a batch and print the top-scoring class of every image.

    @param paths: Image locations, indexed in the same order as the entries of
        ``batch``; only used for labelling the printed results.

    @param batch: Input tensor that is handed to ``model`` unchanged.

    @param model: Network to evaluate. Its output is reduced with
        ``max(dim=1)`` and the winning index is looked up in the category list
        of the default ResNet18 weights.

    @return: Nothing; the results are printed to standard output.
    """
    # Inference must run in eval mode: in training mode BatchNorm normalises
    # with the statistics of the current batch (and updates its running
    # averages), which makes the predictions depend on the batch contents.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(batch)
    finally:
        # Hand the model back in whatever mode the caller had it in.
        model.train(was_training)

    max_scores, preds = outputs.max(dim=1)
    class_names = torchvision.models.ResNet18_Weights.DEFAULT.meta["categories"]

    # `paths` is indexed (not zipped) on purpose: a batch with more rows than
    # paths keeps failing with an IndexError instead of being silently cut.
    for i, (class_idx, score) in enumerate(zip(preds.tolist(), max_scores.tolist())):
        print(f"Image: {paths[i]}")
        print(f"Model output score: {score:.4f}")
        print(f"Predicted class: {class_names[class_idx]}")
        print()
|
|
|
|
|
|
|
|
|
|
|
2025-10-13 14:48:00 +02:00
|
|
|
def main():
    """Exercise 1.3

    Classify the example images three times with the same model: as built by
    default (batch A), with an extra resize to 100x100 (batch B) and with a
    vertical flip applied with probability 1 (batch C).
    """
    paths = [
        "./images/golden retriever.jpg",
        "./images/koala.jpg",
        "./images/pacifier.jpg",
        "./images/rubber duck sculpture.jpg",
        "./images/rubber ducks.jpg",
        "./images/shoehorn.jpg",
        "./images/zoo.jpg",
    ]

    # Batch A is built before the model is created and before anything is
    # printed, so a missing image aborts the run early.
    batch_a = build_batch(paths)
    model = get_model()
    print("Batch A:")
    forward_pass(paths, batch_a, model)

    # The remaining runs only differ in their heading and extra transform.
    augmented_runs = (
        ("Batch B:", transforms.Resize((100, 100))),
        ("Batch C:", transforms.RandomVerticalFlip(1)),
    )
    for heading, augmentation in augmented_runs:
        print(heading)
        forward_pass(paths, build_batch(paths, augmentation), model)
|
2025-10-13 14:48:00 +02:00
|
|
|
|
2025-10-16 14:39:25 +00:00
|
|
|
|
2025-10-13 14:48:00 +02:00
|
|
|
# Run the exercise only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|