From c68f30b15970dafd87e842f67347933d290520ea Mon Sep 17 00:00:00 2001 From: franksim Date: Thu, 16 Oct 2025 14:39:25 +0000 Subject: [PATCH] assignment-a1: adapts build batch --- mmp/a1/main.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/mmp/a1/main.py b/mmp/a1/main.py index c9eb169..cb10388 100644 --- a/mmp/a1/main.py +++ b/mmp/a1/main.py @@ -14,7 +14,7 @@ def pad_to_square(img): return F.pad(img, padding, fill=0, padding_mode='constant') -def build_batch(paths: Sequence[str], size=(224, 224), additional_transforms=[]) -> torch.Tensor: +def build_batch(paths: Sequence[str], transform=None) -> torch.Tensor: """Exercise 1.1 @param paths: A sequence (e.g. list) of strings, each specifying the location of an image file. @@ -23,15 +23,13 @@ def build_batch(paths: Sequence[str], size=(224, 224), additional_transforms=[]) """ preprocess = transforms.Compose([ transforms.Lambda(pad_to_square), - transforms.Resize(size)] - + additional_transforms - + - [transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]), - ]) + transforms.Resize((224, 224)), + *([transform] if transform is not None else []), + transforms.ToTensor() + ] + ) imgs = [] - + for path in paths: img = Image.open(path).convert('RGB') img = preprocess(img) @@ -86,13 +84,13 @@ def main(): forward_pass(paths, batch_a, model) print("Batch B:") - batch_b = build_batch(paths, (400, 400)) + batch_b = build_batch(paths, transforms.Resize((100, 100))) forward_pass(paths, batch_b, model) print("Batch C:") - batch_c = build_batch(paths, additional_transforms=[ - transforms.RandomVerticalFlip(1)]) + batch_c = build_batch(paths, transforms.RandomVerticalFlip(1)) forward_pass(paths, batch_c, model) + if __name__ == "__main__": main()