diff --git a/tensor.c b/tensor.c index d38b8d1..f5df9e4 100644 --- a/tensor.c +++ b/tensor.c @@ -1,8 +1,5 @@ #include "tensor.h" - -/* A helper for f.e. the tensor_add_scalar function */ -dtype _tensor_scalar_wise_helper; - +#include "tensoriterator.h" tensor tensor_new(void) { @@ -189,41 +186,40 @@ uint8_t tensor_cpy(tensor t1, const tensor t2) return 1; } -dtype _tensor_dtype_add_helper(dtype d) { return DTYPE_ADD(d, _tensor_scalar_wise_helper); } -dtype _tensor_dtype_sub_helper(dtype d) { return DTYPE_SUB(d, _tensor_scalar_wise_helper); } -dtype _tensor_dtype_mul_helper(dtype d) { return DTYPE_MUL(d, _tensor_scalar_wise_helper); } -dtype _tensor_dtype_div_helper(dtype d) { return DTYPE_DIV(d, _tensor_scalar_wise_helper); } - void tensor_add_scalar(tensor t, dtype n) { assert(!tensor_is_empty(t)); - _tensor_scalar_wise_helper = n; - tensor_map(t, &_tensor_dtype_add_helper); + tensoriter_scalar iter = tensoriter_scalar_create(t); + tensoriter_scalar_map_add(iter, n); + tensoriter_scalar_destroy(iter); } void tensor_sub_scalar(tensor t, dtype n) { assert(!tensor_is_empty(t)); - _tensor_scalar_wise_helper = n; - tensor_map(t, &_tensor_dtype_sub_helper); + tensoriter_scalar iter = tensoriter_scalar_create(t); + tensoriter_scalar_map_sub(iter, n); + tensoriter_scalar_destroy(iter); } void tensor_mul_scalar(tensor t, dtype n) { assert(!tensor_is_empty(t)); - _tensor_scalar_wise_helper = n; - tensor_map(t, &_tensor_dtype_mul_helper); + tensoriter_scalar iter = tensoriter_scalar_create(t); + tensoriter_scalar_map_mul(iter, n); + tensoriter_scalar_destroy(iter); } void tensor_div_scalar(tensor t, dtype n) { assert(!tensor_is_empty(t)); - _tensor_scalar_wise_helper = n; - tensor_map(t, &_tensor_dtype_div_helper); + tensoriter_scalar iter = tensoriter_scalar_create(t); + tensoriter_scalar_map_div(iter, n); + tensoriter_scalar_destroy(iter); } uint8_t tensor_add_inplace(tensor t1, const tensor t2) @@ -283,16 +279,6 @@ tensor tensor_sub(const tensor t1, const tensor t2) } -void tensor_map(tensor t, dtype (*func)(dtype)) -{ - assert(!tensor_is_empty(t)); - - uint32_t i; - for(i = 0; i < t->num_elem; i++) { - t->elements[i] = func(t->elements[i]); - } -} - void tensor_print(const tensor t) { uint32_t i, j; diff --git a/tensor.h b/tensor.h index f95ea25..0a6f2ff 100644 --- a/tensor.h +++ b/tensor.h @@ -64,7 +64,6 @@ uint8_t tensor_sub_inplace(tensor t1, const tensor t2); tensor tensor_add(const tensor t1, const tensor t2); tensor tensor_sub(const tensor t1, const tensor t2); -void tensor_map(tensor t, dtype (*func)(dtype)); void tensor_print(const tensor t); #endif diff --git a/tensoriterator.c b/tensoriterator.c new file mode 100644 index 0000000..dcff0b6 --- /dev/null +++ b/tensoriterator.c @@ -0,0 +1,74 @@ +#include "tensoriterator.h" + +tensoriter_scalar tensoriter_scalar_create(tensor t) +{ + assert(!tensor_is_empty(t)); + + tensoriter_scalar it = malloc(sizeof(struct _tensor_scalar_iterator)); + if (it == NULL) return NULL; + + it->current = t->elements; + it->length = t->num_elem; + + return it; +} + +void tensoriter_scalar_destroy(tensoriter_scalar it) +{ + free(it); +} + +uint8_t tensoriter_scalar_next(tensoriter_scalar it) +{ + if (it->length == 1) { + tensoriter_scalar_destroy(it); + return 0; + } + + it->current++; + it->length--; + + return 1; +} + +dtype *tensoriter_scalar_get(tensoriter_scalar it) +{ + return it->current; +} + +void tensoriter_scalar_map(tensoriter_scalar it, dtype (*func)(dtype)) +{ + do { + dtype el = *tensoriter_scalar_get(it); + el = func(el); + } while(tensoriter_scalar_next(it)); +} + +void tensoriter_scalar_map_add(tensoriter_scalar it, dtype scalar) +{ + do { + DTYPE_ADD(*tensoriter_scalar_get(it), scalar); + } while(tensoriter_scalar_next(it)); +} + +void tensoriter_scalar_map_sub(tensoriter_scalar it, dtype scalar) +{ + do { + DTYPE_SUB(*tensoriter_scalar_get(it), scalar); + } while(tensoriter_scalar_next(it)); +} + +void tensoriter_scalar_map_mul(tensoriter_scalar it, dtype scalar) +{ + do { + DTYPE_MUL(*tensoriter_scalar_get(it), scalar); + } while(tensoriter_scalar_next(it)); +} + +void tensoriter_scalar_map_div(tensoriter_scalar it, dtype scalar) +{ + do { + DTYPE_DIV(*tensoriter_scalar_get(it), scalar); + } while(tensoriter_scalar_next(it)); +} + diff --git a/tensoriterator.h b/tensoriterator.h new file mode 100644 index 0000000..ee2910d --- /dev/null +++ b/tensoriterator.h @@ -0,0 +1,25 @@ +#ifndef _TENSORITERATOR_H_ +#define _TENSORITERATOR_H_ + +#include "tensor.h" + +typedef struct _tensor_scalar_iterator { + dtype *current; + uint32_t length; +} * tensoriter_scalar; + +tensoriter_scalar tensoriter_scalar_create(tensor t); +void tensoriter_scalar_reset(tensoriter_scalar it); +void tensoriter_scalar_destroy(tensoriter_scalar it); + +uint8_t tensoriter_scalar_next(tensoriter_scalar it); +dtype *tensoriter_scalar_get(tensoriter_scalar it); + +void tensoriter_scalar_map(tensoriter_scalar it, dtype (*func)(dtype)); +void tensoriter_scalar_map_add(tensoriter_scalar it, dtype scalar); +void tensoriter_scalar_map_sub(tensoriter_scalar it, dtype scalar); +void tensoriter_scalar_map_mul(tensoriter_scalar it, dtype scalar); +void tensoriter_scalar_map_div(tensoriter_scalar it, dtype scalar); + +#endif // _TENSORITERATOR_H_ +