diff --git a/tensoriter.c b/tensoriter.c index 93fa85e..a782b76 100644 --- a/tensoriter.c +++ b/tensoriter.c @@ -2,19 +2,20 @@ tensoriter_scalar tensoriter_scalar_create(tensor t) { - /* Creates an iterator over the values of a tensor. + /* Creates an iterator over the values of a tensor. If two tensors have the + * same size the iterator will always iterate over them in the same order. * * @param t The tensor to iterate over * - * @return The iterator + * @return The iterator, NULL in case of an error */ assert(!tensor_is_empty(t)); tensoriter_scalar it = malloc(sizeof(struct _tensor_scalar_iterator)); if (it == NULL) return NULL; - it->current = t->elements; - it->length = t->num_elem; + it->t = t; + it->index = calloc(sizeof(uint32_t), t->rank); return it; } @@ -25,6 +26,7 @@ void tensoriter_scalar_destroy(tensoriter_scalar it) * * @param it The iterator to destroy */ + free(it->index); free(it); } @@ -38,37 +40,49 @@ bool tensoriter_scalar_next(tensoriter_scalar it) * * @return true if there is a next value, false otherwise */ - if (it->length == 1) { + bool end = true; + + for (uint8_t i = 0; i < it->t->rank; i++) { + if (it->index[i] < it->t->size[i] - 1) { + (it->index[i])++; + end = false; + break; + } else { + it->index[i] = 0; + } + } + + if (end) { tensoriter_scalar_destroy(it); return false; } - it->current++; - it->length--; - return true; } -dtype tensoriter_scalar_get(tensoriter_scalar it) +dtype tensoriter_scalar_get(tensoriter_scalar it, bool *success) { /* Gets the current value of the iterator. * * @param it The iterator to operate on + * @param success Is set if not NULL and defines whether the get operation + * was successful * - * @return A pointer to the current value + * @return The current value of the iterator */ - return *(it->current); + return tensor_get(it->t, it->index, success); } -void tensoriter_scalar_set(tensoriter_scalar it, dtype value) +bool tensoriter_scalar_set(tensoriter_scalar it, dtype value) { - /* Sets the current value of the iterator. + /* Sets the value of the tensor which the iterator is pointing to at the + * moment. * * @param it The iterator to operate on * @param value The value to insert * */ - *(it->current) = value; + return tensor_set(it->t, it->index, value); } void tensoriter_scalar_map(tensoriter_scalar it, dtype (*func)(dtype)) @@ -80,7 +94,7 @@ void tensoriter_scalar_map(tensoriter_scalar it, dtype (*func)(dtype)) * @param func The map function that is called */ do { - dtype x = tensoriter_scalar_get(it); + dtype x = tensoriter_scalar_get(it, NULL); tensoriter_scalar_set(it, func(x)); } while(tensoriter_scalar_next(it)); } @@ -93,7 +107,7 @@ void tensoriter_scalar_map_add(tensoriter_scalar it, dtype scalar) * @param scalar The value to add */ do { - dtype x = tensoriter_scalar_get(it); + dtype x = tensoriter_scalar_get(it, NULL); tensoriter_scalar_set(it, DTYPE_ADD(x, scalar)); } while(tensoriter_scalar_next(it)); } @@ -106,7 +120,7 @@ void tensoriter_scalar_map_sub(tensoriter_scalar it, dtype scalar) * @param scalar The value to subtract */ do { - dtype x = tensoriter_scalar_get(it); + dtype x = tensoriter_scalar_get(it, NULL); tensoriter_scalar_set(it, DTYPE_SUB(x, scalar)); } while(tensoriter_scalar_next(it)); } @@ -119,7 +133,7 @@ void tensoriter_scalar_map_mul(tensoriter_scalar it, dtype scalar) * @param scalar The value to multiply */ do { - dtype x = tensoriter_scalar_get(it); + dtype x = tensoriter_scalar_get(it, NULL); tensoriter_scalar_set(it, DTYPE_MUL(x, scalar)); } while(tensoriter_scalar_next(it)); } @@ -132,7 +146,7 @@ void tensoriter_scalar_map_div(tensoriter_scalar it, dtype scalar) * @param scalar The value to divide by */ do { - dtype x = tensoriter_scalar_get(it); + dtype x = tensoriter_scalar_get(it, NULL); tensoriter_scalar_set(it, DTYPE_DIV(x, scalar)); } while(tensoriter_scalar_next(it)); } diff --git a/tensoriter.h b/tensoriter.h index a5b9355..5b4852e 100644 --- a/tensoriter.h +++ b/tensoriter.h @@ -4,16 +4,16 @@ #include "tensor.h" typedef struct _tensor_scalar_iterator { - dtype *current; - uint32_t length; + tensor t; + uint32_t *index; } * tensoriter_scalar; tensoriter_scalar tensoriter_scalar_create(tensor t); void tensoriter_scalar_destroy(tensoriter_scalar it); bool tensoriter_scalar_next(tensoriter_scalar it); -dtype tensoriter_scalar_get(tensoriter_scalar it); -void tensoriter_scalar_set(tensoriter_scalar it, dtype value); +dtype tensoriter_scalar_get(tensoriter_scalar it, bool *success); +bool tensoriter_scalar_set(tensoriter_scalar it, dtype value); void tensoriter_scalar_map(tensoriter_scalar it, dtype (*func)(dtype)); void tensoriter_scalar_map_add(tensoriter_scalar it, dtype scalar); diff --git a/tests/tensoriter_test.c b/tests/tensoriter_test.c index 9ddce68..1d58d5c 100644 --- a/tests/tensoriter_test.c +++ b/tests/tensoriter_test.c @@ -39,7 +39,9 @@ void tensoriter_test_scalar_get(void) uint32_t contained = 0; tensoriter_scalar iter = tensoriter_scalar_create(t); do { - uint32_t value = (uint32_t) tensoriter_scalar_get(iter); + bool success; + uint32_t value = (uint32_t) tensoriter_scalar_get(iter, &success); + tensor_assert(success, "mute"); tensor_assert(((1 << (value - 1)) & contained) == 0, "mute"); contained |= 1 << (value - 1); } while (tensoriter_scalar_next(iter));