From 9372abd0a3482c27cefe66a4613fa72e491d6bec Mon Sep 17 00:00:00 2001 From: franksim Date: Thu, 16 Oct 2025 14:12:01 +0000 Subject: [PATCH] assignment-a1: implements mask function --- mmp/a1/tensors.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mmp/a1/tensors.py b/mmp/a1/tensors.py index 0c7c329..4669751 100644 --- a/mmp/a1/tensors.py +++ b/mmp/a1/tensors.py @@ -1,6 +1,4 @@ import torch -from PIL import Image -from torchvision.transforms import ToTensor def avg_color(img: torch.Tensor): @@ -9,9 +7,9 @@ def avg_color(img: torch.Tensor): def mask(foreground: torch.Tensor, background: torch.Tensor, mask_tensor: torch.Tensor, threshold: float): - raise NotImplementedError() + mask = mask_tensor > threshold + return torch.where(mask, foreground, background) def add_matrix_vector(matrix: torch.Tensor, vector: torch.Tensor): return matrix.add(vector) -