assignment-a1: adds forward pass

This commit is contained in:
franksim
2025-10-14 07:45:19 +00:00
parent 142b93a119
commit 19ea10dbe7

View File

@@ -1,5 +1,6 @@
from typing import Sequence from typing import Sequence
import torch import torch
import torchvision
from torchvision.transforms import functional as F from torchvision.transforms import functional as F
from torchvision import transforms, models from torchvision import transforms, models
from PIL import Image from PIL import Image
@@ -51,7 +52,31 @@ def main():
Put all your code for exercise 1.3 here. Put all your code for exercise 1.3 here.
""" """
raise NotImplementedError()
paths = [
"./images/golden retriever.jpg",
"./images/koala.jpg",
"./images/pacifier.jpg",
"./images/rubber duck sculpture.jpg",
"./images/rubber ducks.jpg",
"./images/shoehorn.jpg",
"./images/zoo.jpg",
]
batch = build_batch(paths)
model = get_model()
with torch.no_grad():
outputs = model(batch)
max_scores, preds = outputs.max(dim=1)
class_names = torchvision.models.ResNet18_Weights.DEFAULT.meta["categories"]
for i, (p, s) in enumerate(zip(preds, max_scores)):
print(f"Image: {paths[i]}")
print(f" Model output score: {s.item():.4f}")
print(f" Predicted class: {class_names[p.item()]}")
print()
if __name__ == "__main__": if __name__ == "__main__":