diff --git a/mmp/a1/tensors.py b/mmp/a1/tensors.py index 0c7c329..4669751 100644 --- a/mmp/a1/tensors.py +++ b/mmp/a1/tensors.py @@ -1,6 +1,4 @@ import torch -from PIL import Image -from torchvision.transforms import ToTensor def avg_color(img: torch.Tensor): @@ -9,9 +7,9 @@ def avg_color(img: torch.Tensor): def mask(foreground: torch.Tensor, background: torch.Tensor, mask_tensor: torch.Tensor, threshold: float): - raise NotImplementedError() + mask = mask_tensor > threshold + return torch.where(mask, foreground, background) def add_matrix_vector(matrix: torch.Tensor, vector: torch.Tensor): return matrix.add(vector) -