main: fix
This commit is contained in:
@@ -14,4 +14,4 @@ def mask(foreground: torch.Tensor, background: torch.Tensor, mask_tensor: torch.
|
|||||||
|
|
||||||
|
|
||||||
def add_matrix_vector(matrix: torch.Tensor, vector: torch.Tensor):
|
def add_matrix_vector(matrix: torch.Tensor, vector: torch.Tensor):
|
||||||
return matrix.add(vector)
|
return matrix + vector.unsqueeze(1)
|
||||||
|
|||||||
Reference in New Issue
Block a user