diff --git a/tensor.c b/tensor.c index 8888dc7..92886c4 100644 --- a/tensor.c +++ b/tensor.c @@ -346,6 +346,31 @@ bool tensor_mul_inplace(tensor t1, const tensor t2) return true; } +bool tensor_div_inplace(tensor t1, const tensor t2) +{ + /* Divides the values of t2 by the values of t1 element wise. t1 and t2 + * need to have the same shape. + * + * @param t1 The tensor to devide + * @param t2 The tensor to devide by + * + * @return true when successful, false otherwise + */ + assert(!tensor_is_empty(t1)); + assert(!tensor_is_empty(t2)); + + uint32_t i; + + if(t1->rank != t2->rank) return false; + for(i = 0; i < t1->rank; i++) { + if(t1->size[i] != t2->size[i]) return false; + } + for(i = 0; i < t1->num_elem; i++) { + t1->elements[i] = DTYPE_DIV(t1->elements[i], t2->elements[i]); + } + return true; +} + tensor tensor_add(const tensor t1, const tensor t2) { /* Adds two tensors returning the result as a tensor. t1 and t2 need to @@ -425,6 +450,32 @@ tensor tensor_mul(const tensor t1, const tensor t2) return t3; } +tensor tensor_div(const tensor t1, const tensor t2) +{ + /* Divides two tensors element wise returning the result as a tensor. t1 + * and t2 need to have the same shape. + * + * @param t1 The dividend + * @param t2 The divisor + * + * @return The result of the operation, NULL if an error occurs + */ + assert(!tensor_is_empty(t1)); + assert(!tensor_is_empty(t2)); + + tensor t3 = tensor_new(); + if(t3 == NULL) return NULL; + if (!tensor_cpy(t3, t1)) { + tensor_destroy(t3); + return NULL; + } + if (!tensor_div_inplace(t3, t2)) { + tensor_destroy(t3); + return NULL; + } + return t3; +} + void tensor_print(const tensor t) { /* Prints a tensor to stdout. diff --git a/tensor.h b/tensor.h index 96a1438..02b682e 100644 --- a/tensor.h +++ b/tensor.h @@ -41,9 +41,11 @@ bool tensor_cpy(tensor t1, const tensor t2); bool tensor_add_inplace(tensor t1, const tensor t2); bool tensor_sub_inplace(tensor t1, const tensor t2); bool tensor_mul_inplace(tensor t1, const tensor t2); +bool tensor_div_inplace(tensor t1, const tensor t2); tensor tensor_add(const tensor t1, const tensor t2); tensor tensor_sub(const tensor t1, const tensor t2); tensor tensor_mul(const tensor t1, const tensor t2); +tensor tensor_div(const tensor t1, const tensor t2); void tensor_print(const tensor t);