assignment-a1: implements mask function

This commit is contained in:
franksim
2025-10-16 14:12:01 +00:00
parent b7825b0831
commit 9372abd0a3

View File

@@ -1,6 +1,4 @@
import torch import torch
from PIL import Image
from torchvision.transforms import ToTensor
def avg_color(img: torch.Tensor): def avg_color(img: torch.Tensor):
@@ -9,9 +7,9 @@ def avg_color(img: torch.Tensor):
def mask(foreground: torch.Tensor, background: torch.Tensor, mask_tensor: torch.Tensor, threshold: float): def mask(foreground: torch.Tensor, background: torch.Tensor, mask_tensor: torch.Tensor, threshold: float):
raise NotImplementedError() mask = mask_tensor > threshold
return torch.where(mask, foreground, background)
def add_matrix_vector(matrix: torch.Tensor, vector: torch.Tensor): def add_matrix_vector(matrix: torch.Tensor, vector: torch.Tensor):
return matrix.add(vector) return matrix.add(vector)