From d173712d9d4e11d26f9d3dc98186e8beb156ddbe Mon Sep 17 00:00:00 2001 From: franksim Date: Mon, 20 Oct 2025 12:58:57 +0000 Subject: [PATCH] main: fix --- mmp/a1/tensors.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mmp/a1/tensors.py b/mmp/a1/tensors.py index 3d87203..72e1c29 100644 --- a/mmp/a1/tensors.py +++ b/mmp/a1/tensors.py @@ -7,6 +7,9 @@ def avg_color(img: torch.Tensor): def mask(foreground: torch.Tensor, background: torch.Tensor, mask_tensor: torch.Tensor, threshold: float): mask = mask_tensor > threshold + if foreground.dim() == 3: + mask = mask.unsqueeze(0) + mask = mask.expand(foreground.shape) return torch.where(mask, foreground, background)