diff --git a/mmp/a1/tensors.py b/mmp/a1/tensors.py index 4574de1..0c7c329 100644 --- a/mmp/a1/tensors.py +++ b/mmp/a1/tensors.py @@ -1,10 +1,17 @@ import torch +from PIL import Image +from torchvision.transforms import ToTensor + def avg_color(img: torch.Tensor): - raise NotImplementedError() + result = img.mean(dim=(1, 2)).tolist() + return tuple(result) + def mask(foreground: torch.Tensor, background: torch.Tensor, mask_tensor: torch.Tensor, threshold: float): raise NotImplementedError() + def add_matrix_vector(matrix: torch.Tensor, vector: torch.Tensor): - raise NotImplementedError() + return matrix.add(vector) +