formatting

This commit is contained in:
franksim
2025-11-07 11:20:08 +01:00
parent 8fc3559d6c
commit b159d76517
8 changed files with 119 additions and 82 deletions

View File

@@ -5,7 +5,12 @@ def avg_color(img: torch.Tensor):
return img.mean(dim=(1, 2))
def mask(foreground: torch.Tensor, background: torch.Tensor, mask_tensor: torch.Tensor, threshold: float):
def mask(
foreground: torch.Tensor,
background: torch.Tensor,
mask_tensor: torch.Tensor,
threshold: float,
):
mask = mask_tensor > threshold
if foreground.dim() == 3:
mask = mask.unsqueeze(0)