From 0ef0eeb0479e28b41883078732447be813a97da3 Mon Sep 17 00:00:00 2001 From: Elias Kohout Date: Sat, 18 Mar 2023 20:56:40 +0100 Subject: [PATCH] Added tensor subtraction. --- tensor.c | 16 ++++++++++++++++ tensor.h | 1 + 2 files changed, 17 insertions(+) diff --git a/tensor.c b/tensor.c index 234b70b..4477d4b 100644 --- a/tensor.c +++ b/tensor.c @@ -238,6 +238,22 @@ int tensor_add(tensor t1, const tensor t2) return 1; } +int tensor_sub(tensor t1, const tensor t2) +{ + assert(!tensor_is_empty(t1)); + assert(!tensor_is_empty(t2)); + + int i; + if(t1->dimension != t2->dimension) return 0; + for(i = 0; i < t1->dimension; i++) { + if(t1->size[i] != t2->size[i]) return 0; + } + for(i = 0; i < t1->num_elem; i++) { + t1->elements[i] -= t2->elements[i]; + } + return 1; +} + void tensor_for_each_elem(tensor t, dtype (*func)(dtype)) { assert(!tensor_is_empty(t)); diff --git a/tensor.h b/tensor.h index c53e7e1..bf29307 100644 --- a/tensor.h +++ b/tensor.h @@ -42,6 +42,7 @@ void tensor_sub_scalar(tensor t, dtype n); void tensor_mult_scalar(tensor t, dtype n); void tensor_div_scalar(tensor t, dtype n); int tensor_add(tensor t1, const tensor t2); +int tensor_sub(tensor t1, const tensor t2); void tensor_for_each_elem(tensor t, dtype (*func)(dtype)); void tensor_print(const tensor t);