assignment-a1: implements mask function
This commit is contained in:
@@ -1,6 +1,4 @@
|
|||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
|
||||||
from torchvision.transforms import ToTensor
|
|
||||||
|
|
||||||
|
|
||||||
def avg_color(img: torch.Tensor):
|
def avg_color(img: torch.Tensor):
|
||||||
@@ -9,9 +7,9 @@ def avg_color(img: torch.Tensor):
|
|||||||
|
|
||||||
|
|
||||||
def mask(foreground: torch.Tensor, background: torch.Tensor, mask_tensor: torch.Tensor, threshold: float):
|
def mask(foreground: torch.Tensor, background: torch.Tensor, mask_tensor: torch.Tensor, threshold: float):
|
||||||
raise NotImplementedError()
|
mask = mask_tensor > threshold
|
||||||
|
return torch.where(mask, foreground, background)
|
||||||
|
|
||||||
|
|
||||||
def add_matrix_vector(matrix: torch.Tensor, vector: torch.Tensor):
|
def add_matrix_vector(matrix: torch.Tensor, vector: torch.Tensor):
|
||||||
return matrix.add(vector)
|
return matrix.add(vector)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user