diff --git a/mmp/a1/tensors.py b/mmp/a1/tensors.py index 3d87203..72e1c29 100644 --- a/mmp/a1/tensors.py +++ b/mmp/a1/tensors.py @@ -7,6 +7,9 @@ def avg_color(img: torch.Tensor): def mask(foreground: torch.Tensor, background: torch.Tensor, mask_tensor: torch.Tensor, threshold: float): mask = mask_tensor > threshold + if foreground.dim() == 3: + mask = mask.unsqueeze(0) + mask = mask.expand(foreground.shape) return torch.where(mask, foreground, background)