Files
mmp_wise2526_franksim/mmp/a1/tensors.py
2025-10-20 13:04:19 +00:00

18 lines
495 B
Python

import torch
def avg_color(img: torch.Tensor) -> torch.Tensor:
    """Return the mean of ``img`` over dimensions 1 and 2.

    For a channel-first image of shape (C, H, W) this is the average color:
    one value per channel. (Assumes channel-first layout, i.e. that dims 1
    and 2 are the spatial axes — confirm against callers.)

    Args:
        img: Tensor with at least three dimensions. Integer and boolean
            tensors (such as ``uint8`` images) are accepted and promoted to
            ``float32`` before averaging.

    Returns:
        ``img`` with dimensions 1 and 2 reduced away by averaging.
    """
    # Tensor.mean raises for integer dtypes ("could not infer output dtype"),
    # so promote non-floating, non-complex inputs instead of failing.
    if not (img.is_floating_point() or img.is_complex()):
        img = img.float()
    return img.mean(dim=(1, 2))
def mask(
    foreground: torch.Tensor,
    background: torch.Tensor,
    mask_tensor: torch.Tensor,
    threshold: float,
) -> torch.Tensor:
    """Select from ``foreground`` where ``mask_tensor`` exceeds ``threshold``.

    Elements where ``mask_tensor > threshold`` come from ``foreground``; all
    other elements come from ``background``. The three tensors are combined
    using standard PyTorch broadcasting. A single (H, W) mask therefore
    applies to every channel of a (C, H, W) or (N, C, H, W) foreground, and
    masks of shape (1, H, W) or (C, H, W) also work.

    Args:
        foreground: Values used where the mask is above the threshold.
        background: Values used elsewhere; must broadcast with the others.
        mask_tensor: Values compared against ``threshold``.
        threshold: Strict lower bound; equality selects ``background``.

    Returns:
        Tensor of the broadcast shape of the three inputs.
    """
    # torch.where broadcasts its operands itself. Manually unsqueezing and
    # expanding the mask to foreground.shape broke for masks that already
    # carry a channel axis, e.g. (1, H, W) or (C, H, W).
    selected = mask_tensor > threshold
    return torch.where(selected, foreground, background)
def add_matrix_vector(matrix: torch.Tensor, vector: torch.Tensor):
    """Add ``vector`` to ``matrix`` row by row.

    A new axis is inserted at position 1 of ``vector``; for a 1-D vector of
    length M this gives an (M, 1) column. Broadcasting then adds
    ``vector[i]`` to every entry in row ``i`` of an (M, N) matrix.
    (Presumably ``vector`` is 1-D — confirm against callers.)
    """
    column = vector[:, None]  # equivalent to vector.unsqueeze(1)
    return matrix + column