formatting
This commit is contained in:
@@ -5,7 +5,12 @@ def avg_color(img: torch.Tensor):
|
||||
return img.mean(dim=(1, 2))
|
||||
|
||||
|
||||
def mask(foreground: torch.Tensor, background: torch.Tensor, mask_tensor: torch.Tensor, threshold: float):
|
||||
def mask(
|
||||
foreground: torch.Tensor,
|
||||
background: torch.Tensor,
|
||||
mask_tensor: torch.Tensor,
|
||||
threshold: float,
|
||||
):
|
||||
mask = mask_tensor > threshold
|
||||
if foreground.dim() == 3:
|
||||
mask = mask.unsqueeze(0)
|
||||
|
||||
Reference in New Issue
Block a user