diff --git a/mmp/a1/tensors.py b/mmp/a1/tensors.py index f96fbac..3d87203 100644 --- a/mmp/a1/tensors.py +++ b/mmp/a1/tensors.py @@ -2,7 +2,7 @@ import torch def avg_color(img: torch.Tensor): - return img.mean(dim=(1, 2)).tolist() + return img.mean(dim=(1, 2)) def mask(foreground: torch.Tensor, background: torch.Tensor, mask_tensor: torch.Tensor, threshold: float):