Files
mmp_wise2526_franksim/mmp/a1/tensors.py

23 lines
514 B
Python
Raw Normal View History

2025-10-13 14:48:00 +02:00
import torch
2025-10-13 14:48:00 +02:00
def avg_color(img: torch.Tensor) -> torch.Tensor:
    """Return the mean value of ``img`` over its dimensions 1 and 2.

    For a channel-first ``(C, H, W)`` image this is the average color, one
    value per channel. The channel-first layout is an assumption: the code
    only fixes that ``img`` has at least three dimensions.

    Integer and bool inputs (e.g. ``uint8`` pixel data) are promoted to
    float32 first, because ``Tensor.mean`` rejects non-floating dtypes.
    Floating and complex inputs keep their dtype.

    Args:
        img: Tensor with at least three dimensions.

    Returns:
        Tensor of means with dimensions 1 and 2 reduced away.
    """
    if not (img.is_floating_point() or img.is_complex()):
        # mean() cannot infer an output dtype for integer/bool tensors.
        img = img.float()
    return img.mean(dim=(1, 2))
2025-10-13 14:48:00 +02:00
2025-11-07 11:20:08 +01:00
def mask(
    foreground: torch.Tensor,
    background: torch.Tensor,
    mask_tensor: torch.Tensor,
    threshold: float,
) -> torch.Tensor:
    """Composite ``foreground`` over ``background`` using a thresholded mask.

    Each output element comes from ``foreground`` where ``mask_tensor`` is
    strictly greater than ``threshold``, and from ``background`` elsewhere.

    ``torch.where`` broadcasts its condition against both value tensors.
    A 2-D ``(H, W)`` mask therefore applies to every channel of a
    ``(C, H, W)`` foreground without reshaping. A mask that already has a
    leading size-1 axis, such as ``(1, H, W)``, works as well.

    Args:
        foreground: Values taken where the mask exceeds ``threshold``.
        background: Values taken elsewhere.
        mask_tensor: Tensor compared against ``threshold``; must broadcast
            with ``foreground`` and ``background``.
        threshold: Strict lower bound for selecting ``foreground``.

    Returns:
        Tensor with the broadcast shape of the three inputs.
    """
    # Local name differs from the function name to avoid shadowing it.
    keep_foreground = mask_tensor > threshold
    return torch.where(keep_foreground, foreground, background)
2025-10-13 14:48:00 +02:00
2025-10-13 14:48:00 +02:00
def add_matrix_vector(matrix: torch.Tensor, vector: torch.Tensor):
    """Add ``vector`` to ``matrix`` as a column, using broadcasting.

    A size-1 axis is inserted at position 1 of ``vector``. For a 1-D
    ``vector`` of length M and an ``(M, N)`` ``matrix``, this adds
    ``vector[i]`` to every entry of row ``i``.

    Note: those shapes are an assumption, to be confirmed against callers.
    Other shapes follow ordinary broadcasting rules.
    """
    as_column = vector[:, None]  # equivalent to vector.unsqueeze(1)
    return matrix + as_column