diff --git a/tensor.c b/tensor.c index e34b127..d0ae2e2 100644 --- a/tensor.c +++ b/tensor.c @@ -1,6 +1,10 @@ #include "tensor.h" #include +/* A helper for f.e. the tensor_add_scalar function */ +dtype _tensor_scalar_wise_helper; + + tensor tensor_new(void) { return calloc(1, sizeof(struct _tensor)); @@ -183,44 +187,42 @@ int tensor_cpy(tensor t1, const tensor t2) return 1; } + +dtype _tensor_dtype_add_helper(dtype d) { return DTYPE_ADD(d, _tensor_scalar_wise_helper); } +dtype _tensor_dtype_sub_helper(dtype d) { return DTYPE_SUB(d, _tensor_scalar_wise_helper); } +dtype _tensor_dtype_mul_helper(dtype d) { return DTYPE_MUL(d, _tensor_scalar_wise_helper); } +dtype _tensor_dtype_div_helper(dtype d) { return DTYPE_DIV(d, _tensor_scalar_wise_helper); } + void tensor_add_scalar(tensor t, dtype n) { assert(!tensor_is_empty(t)); - int i; - for(i = 0; i < t->num_elem; i++) { - t->elements[i] = DTYPE_ADD(t->elements[i], n); - } + _tensor_scalar_wise_helper = n; + tensor_map(t, &_tensor_dtype_add_helper); } void tensor_sub_scalar(tensor t, dtype n) { assert(!tensor_is_empty(t)); - int i; - for(i = 0; i < t->num_elem; i++) { - t->elements[i] = DTYPE_SUB(t->elements[i], n); - } + _tensor_scalar_wise_helper = n; + tensor_map(t, &_tensor_dtype_sub_helper); } -void tensor_mult_scalar(tensor t, dtype n) +void tensor_mul_scalar(tensor t, dtype n) { assert(!tensor_is_empty(t)); - int i; - for(i = 0; i < t->num_elem; i++) { - t->elements[i] = DTYPE_MUL(t->elements[i], n); - } + _tensor_scalar_wise_helper = n; + tensor_map(t, &_tensor_dtype_mul_helper); } void tensor_div_scalar(tensor t, dtype n) { assert(!tensor_is_empty(t)); - int i; - for(i = 0; i < t->num_elem; i++) { - t->elements[i] = DTYPE_DIV(t->elements[i], n); - } + _tensor_scalar_wise_helper = n; + tensor_map(t, &_tensor_dtype_div_helper); } int tensor_add_inplace(tensor t1, const tensor t2) diff --git a/tensor.h b/tensor.h index 4e0a588..343a9c5 100644 --- a/tensor.h +++ b/tensor.h @@ -6,6 +6,7 @@ #include #include #include +#include /* Defining the datatype of the tensor */ typedef float dtype; @@ -53,10 +54,12 @@ int tensor_cpy(tensor t1, const tensor t2); void tensor_add_scalar(tensor t, dtype n); void tensor_sub_scalar(tensor t, dtype n); -void tensor_mult_scalar(tensor t, dtype n); +void tensor_mul_scalar(tensor t, dtype n); void tensor_div_scalar(tensor t, dtype n); int tensor_add_inplace(tensor t1, const tensor t2); int tensor_sub_inplace(tensor t1, const tensor t2); +tensor tensor_add(const tensor t1, const tensor t2); +tensor tensor_sub(const tensor t1, const tensor t2); void tensor_map(tensor t, dtype (*func)(dtype)); void tensor_print(const tensor t);