main: fix

This commit is contained in:
franksim
2025-10-20 13:04:19 +00:00
parent d173712d9d
commit 61edc12522

View File

@@ -14,4 +14,4 @@ def mask(foreground: torch.Tensor, background: torch.Tensor, mask_tensor: torch.
def add_matrix_vector(matrix: torch.Tensor, vector: torch.Tensor):
return matrix.add(vector)
return matrix + vector.unsqueeze(1)