diff --git a/tensoriter.c b/tensoriter.c index bcd5f96..93fa85e 100644 --- a/tensoriter.c +++ b/tensoriter.c @@ -49,7 +49,7 @@ bool tensoriter_scalar_next(tensoriter_scalar it) return true; } -dtype *tensoriter_scalar_get(tensoriter_scalar it) +dtype tensoriter_scalar_get(tensoriter_scalar it) { /* Gets the current value of the iterator. * @@ -57,7 +57,18 @@ dtype *tensoriter_scalar_get(tensoriter_scalar it) * * @return A pointer to the current value */ - return it->current; + return *(it->current); +} + +void tensoriter_scalar_set(tensoriter_scalar it, dtype value) +{ + /* Sets the current value of the iterator. + * + * @param it The iterator to operate on + * @param value The value to insert + * + */ + *(it->current) = value; } void tensoriter_scalar_map(tensoriter_scalar it, dtype (*func)(dtype)) @@ -69,8 +80,8 @@ void tensoriter_scalar_map(tensoriter_scalar it, dtype (*func)(dtype)) * @param func The map function that is called */ do { - dtype *el = tensoriter_scalar_get(it); - *el = func(*el); + dtype x = tensoriter_scalar_get(it); + tensoriter_scalar_set(it, func(x)); } while(tensoriter_scalar_next(it)); } @@ -82,8 +93,8 @@ void tensoriter_scalar_map_add(tensoriter_scalar it, dtype scalar) * @param scalar The value to add */ do { - dtype* x = tensoriter_scalar_get(it); - *x = DTYPE_ADD(*x, scalar); + dtype x = tensoriter_scalar_get(it); + tensoriter_scalar_set(it, DTYPE_ADD(x, scalar)); } while(tensoriter_scalar_next(it)); } @@ -95,8 +106,8 @@ void tensoriter_scalar_map_sub(tensoriter_scalar it, dtype scalar) * @param scalar The value to subtract */ do { - dtype* x = tensoriter_scalar_get(it); - *x = DTYPE_SUB(*x, scalar); + dtype x = tensoriter_scalar_get(it); + tensoriter_scalar_set(it, DTYPE_SUB(x, scalar)); } while(tensoriter_scalar_next(it)); } @@ -108,8 +119,8 @@ void tensoriter_scalar_map_mul(tensoriter_scalar it, dtype scalar) * @param scalar The value to multiply */ do { - dtype* x = tensoriter_scalar_get(it); - *x = DTYPE_MUL(*x, scalar); + dtype x = tensoriter_scalar_get(it); + tensoriter_scalar_set(it, DTYPE_MUL(x, scalar)); } while(tensoriter_scalar_next(it)); } @@ -121,8 +132,8 @@ void tensoriter_scalar_map_div(tensoriter_scalar it, dtype scalar) * @param scalar The value to divide by */ do { - dtype* x = tensoriter_scalar_get(it); - *x = DTYPE_DIV(*x, scalar); + dtype x = tensoriter_scalar_get(it); + tensoriter_scalar_set(it, DTYPE_DIV(x, scalar)); } while(tensoriter_scalar_next(it)); } diff --git a/tensoriter.h b/tensoriter.h index 2de276c..a5b9355 100644 --- a/tensoriter.h +++ b/tensoriter.h @@ -12,7 +12,8 @@ tensoriter_scalar tensoriter_scalar_create(tensor t); void tensoriter_scalar_destroy(tensoriter_scalar it); bool tensoriter_scalar_next(tensoriter_scalar it); -dtype *tensoriter_scalar_get(tensoriter_scalar it); +dtype tensoriter_scalar_get(tensoriter_scalar it); +void tensoriter_scalar_set(tensoriter_scalar it, dtype value); void tensoriter_scalar_map(tensoriter_scalar it, dtype (*func)(dtype)); void tensoriter_scalar_map_add(tensoriter_scalar it, dtype scalar); diff --git a/tests/main.c b/tests/main.c index f62f476..e76edc5 100644 --- a/tests/main.c +++ b/tests/main.c @@ -44,6 +44,7 @@ void tensoriter_test_run_all(void) void (*test_func[NUM_TENSORITER_TEST_FUNC])(void) = { &tensoriter_test_scalar_next, &tensoriter_test_scalar_get, + &tensoriter_test_scalar_set, &tensoriter_test_scalar_map, &tensoriter_test_scalar_map_add, &tensoriter_test_scalar_map_sub, diff --git a/tests/main.h b/tests/main.h index 9a820a7..ecd5bdb 100644 --- a/tests/main.h +++ b/tests/main.h @@ -6,7 +6,7 @@ #include "tensoriter_test.h" #define NUM_TENSOR_TEST_FUNC 18 -#define NUM_TENSORITER_TEST_FUNC 7 +#define NUM_TENSORITER_TEST_FUNC 8 void tensor_test_run_all(void); diff --git a/tests/tensoriter_test.c b/tests/tensoriter_test.c index b72d2f9..9ddce68 100644 --- a/tests/tensoriter_test.c +++ b/tests/tensoriter_test.c @@ -22,8 +22,8 @@ void tensoriter_test_scalar_next(void) void tensoriter_test_scalar_get(void) { /* Depends on tensor_init_one, tensor_set, tensoriter_scalar_next */ - uint32_t s[4] = {2, 4, 4}; - uint32_t index[4] = {0, 0, 0}; + uint32_t s[3] = {2, 4, 4}; + uint32_t index[3] = {0, 0, 0}; tensor t = tensor_new(); tensor_init_one(t, s, 3); @@ -39,7 +39,7 @@ void tensoriter_test_scalar_get(void) uint32_t contained = 0; tensoriter_scalar iter = tensoriter_scalar_create(t); do { - uint32_t value = (uint32_t) *tensoriter_scalar_get(iter); + uint32_t value = (uint32_t) tensoriter_scalar_get(iter); tensor_assert(((1 << (value - 1)) & contained) == 0, "mute"); contained |= 1 << (value - 1); } while (tensoriter_scalar_next(iter)); @@ -48,6 +48,25 @@ void tensoriter_test_scalar_get(void) tensor_destroy(t); } +void tensoriter_test_scalar_set(void) +{ + /* Depends on tensor_init_one, tensor_init_rand, tensoriter_scalar_next */ + uint32_t s[3] = {2, 4, 4}; + + tensor t1 = tensor_new(); + tensor t2 = tensor_new(); + + tensor_init_one(t1, s, 3); + tensor_init_rand(t2, s, 3, (dtype) 30); + + tensoriter_scalar iter = tensoriter_scalar_create(t2); + do { + tensoriter_scalar_set(iter, DTYPE_ONE); + } while (tensoriter_scalar_next(iter)); + + tensor_assert_eq(t1, t2); +} + dtype tensoriter_test_scalar_map_helper(dtype d) { static uint32_t contained = 0; if(((1 << ((uint32_t) d - 1)) & contained) == 0) { @@ -61,8 +80,8 @@ dtype tensoriter_test_scalar_map_helper(dtype d) { void tensoriter_test_scalar_map(void) { /* Depends on tensor_init_one, tensor_init_zero, tensor_set*/ - uint32_t s[4] = {2, 4, 4}; - uint32_t index[4] = {0, 0, 0}; + uint32_t s[3] = {2, 4, 4}; + uint32_t index[3] = {0, 0, 0}; tensor t = tensor_new(); tensor t0 = tensor_new(); diff --git a/tests/tensoriter_test.h b/tests/tensoriter_test.h index 34c14d7..de0065c 100644 --- a/tests/tensoriter_test.h +++ b/tests/tensoriter_test.h @@ -7,6 +7,7 @@ void tensoriter_test_scalar_next(void); void tensoriter_test_scalar_get(void); +void tensoriter_test_scalar_set(void); void tensoriter_test_scalar_map(void); void tensoriter_test_scalar_map_add(void); void tensoriter_test_scalar_map_sub(void);