diff --git a/mmp/a1/main.py b/mmp/a1/main.py index fd807bf..e134645 100644 --- a/mmp/a1/main.py +++ b/mmp/a1/main.py @@ -1,5 +1,6 @@ from typing import Sequence import torch +import torchvision from torchvision.transforms import functional as F from torchvision import transforms, models from PIL import Image @@ -51,7 +52,31 @@ def main(): Put all your code for exercise 1.3 here. """ - raise NotImplementedError() + + paths = [ + "./images/golden retriever.jpg", + "./images/koala.jpg", + "./images/pacifier.jpg", + "./images/rubber duck sculpture.jpg", + "./images/rubber ducks.jpg", + "./images/shoehorn.jpg", + "./images/zoo.jpg", + ] + batch = build_batch(paths) + model = get_model() + + with torch.no_grad(): + outputs = model(batch) + + max_scores, preds = outputs.max(dim=1) + + class_names = torchvision.models.ResNet18_Weights.DEFAULT.meta["categories"] + + for i, (p, s) in enumerate(zip(preds, max_scores)): + print(f"Image: {paths[i]}") + print(f" Model output score: {s.item():.4f}") + print(f" Predicted class: {class_names[p.item()]}") + print() if __name__ == "__main__":