main: fix

This commit is contained in:
franksim
2025-10-20 12:58:57 +00:00
parent 7dd7537a4d
commit d173712d9d

View File

@@ -7,6 +7,9 @@ def avg_color(img: torch.Tensor):
def mask(foreground: torch.Tensor, background: torch.Tensor, mask_tensor: torch.Tensor, threshold: float): def mask(foreground: torch.Tensor, background: torch.Tensor, mask_tensor: torch.Tensor, threshold: float):
mask = mask_tensor > threshold mask = mask_tensor > threshold
if foreground.dim() == 3:
mask = mask.unsqueeze(0)
mask = mask.expand(foreground.shape)
return torch.where(mask, foreground, background) return torch.where(mask, foreground, background)