2025-10-13 14:48:00 +02:00
|
|
|
import torch
|
2025-10-14 09:37:51 +00:00
|
|
|
|
2025-10-13 14:48:00 +02:00
|
|
|
|
|
|
|
|
def avg_color(img: torch.Tensor):
    """Return the mean of ``img`` over dimensions 1 and 2.

    For a channel-first ``(C, H, W)`` image this is the average value of each
    channel, a tensor of shape ``(C,)``. The channel-first layout is inferred
    from the reduced dims; confirm against callers.

    Integer and boolean tensors (e.g. ``uint8`` images) are cast to
    ``float32`` first. ``torch.mean`` only accepts floating-point or complex
    inputs and would otherwise raise ``RuntimeError``. Floating-point and
    complex inputs keep their original dtype.
    """
    if not (img.is_floating_point() or img.is_complex()):
        # torch.mean cannot infer an output dtype for integral/bool input.
        img = img.float()
    return img.mean(dim=(1, 2))
|
2025-10-14 09:37:51 +00:00
|
|
|
|
2025-10-13 14:48:00 +02:00
|
|
|
|
|
|
|
|
def mask(foreground: torch.Tensor, background: torch.Tensor, mask_tensor: torch.Tensor, threshold: float):
    """Composite ``foreground`` over ``background`` using a thresholded mask.

    An element is taken from ``foreground`` where ``mask_tensor > threshold``
    (strictly greater) and from ``background`` everywhere else.

    ``mask_tensor`` is broadcast to ``foreground.shape``, with dims aligned
    from the right. A 2-D ``(H, W)`` mask therefore applies to every channel
    of a ``(C, H, W)`` image and to every image/channel of a ``(B, C, H, W)``
    batch. A ``(1, H, W)`` mask also works for a ``(C, H, W)`` image. The
    channel-first layout is assumed from the rest of this module; confirm
    against callers. ``background`` must be broadcastable with
    ``foreground``'s shape.

    Returns:
        Tensor from ``torch.where(selected, foreground, background)``.

    Raises:
        RuntimeError: if ``mask_tensor`` cannot be broadcast to
        ``foreground.shape``.
    """
    # Named ``selected`` rather than ``mask`` to avoid shadowing this function.
    selected = mask_tensor > threshold
    # No unsqueeze needed: expand right-aligns dims, so lower-rank masks gain
    # leading dims automatically. The former unsqueeze(0) for 3-D foregrounds
    # changed nothing for 2-D masks and made any 3-D mask, e.g. (1, H, W),
    # fail (4 mask dims vs 3 target sizes).
    selected = selected.expand(foreground.shape)
    return torch.where(selected, foreground, background)
|
2025-10-13 14:48:00 +02:00
|
|
|
|
2025-10-14 09:37:51 +00:00
|
|
|
|
2025-10-13 14:48:00 +02:00
|
|
|
def add_matrix_vector(matrix: torch.Tensor, vector: torch.Tensor):
    """Return ``matrix + vector`` using PyTorch broadcasting.

    Dimensions are aligned from the right. For example, if ``matrix`` is
    ``(m, n)`` and ``vector`` is ``(n,)``, the vector is added to every row.
    A new tensor is returned; neither input is modified.
    """
    return torch.add(matrix, vector)
|