diff --git a/mmp/a1/tensors.py b/mmp/a1/tensors.py index 72e1c29..4c926c6 100644 --- a/mmp/a1/tensors.py +++ b/mmp/a1/tensors.py @@ -14,4 +14,4 @@ def mask(foreground: torch.Tensor, background: torch.Tensor, mask_tensor: torch. def add_matrix_vector(matrix: torch.Tensor, vector: torch.Tensor): - return matrix.add(vector) + return matrix + vector.unsqueeze(1)