diff --git a/tensor.c b/tensor.c index 234b70b..4477d4b 100644 --- a/tensor.c +++ b/tensor.c @@ -238,6 +238,22 @@ int tensor_add(tensor t1, const tensor t2) return 1; } +int tensor_sub(tensor t1, const tensor t2) +{ + assert(!tensor_is_empty(t1)); + assert(!tensor_is_empty(t2)); + + int i; + if(t1->dimension != t2->dimension) return 0; + for(i = 0; i < t1->dimension; i++) { + if(t1->size[i] != t2->size[i]) return 0; + } + for(i = 0; i < t1->num_elem; i++) { + t1->elements[i] -= t2->elements[i]; + } + return 1; +} + void tensor_for_each_elem(tensor t, dtype (*func)(dtype)) { assert(!tensor_is_empty(t)); diff --git a/tensor.h b/tensor.h index c53e7e1..bf29307 100644 --- a/tensor.h +++ b/tensor.h @@ -42,6 +42,7 @@ void tensor_sub_scalar(tensor t, dtype n); void tensor_mult_scalar(tensor t, dtype n); void tensor_div_scalar(tensor t, dtype n); int tensor_add(tensor t1, const tensor t2); +int tensor_sub(tensor t1, const tensor t2); void tensor_for_each_elem(tensor t, dtype (*func)(dtype)); void tensor_print(const tensor t);