Added more comments.

This commit is contained in:
2023-09-03 10:06:46 +02:00
parent 427c75f1a6
commit 42546feae1
2 changed files with 54 additions and 4 deletions

View File

@@ -2,6 +2,12 @@
tensoriter_scalar tensoriter_scalar_create(tensor t) tensoriter_scalar tensoriter_scalar_create(tensor t)
{ {
/* Creates an iterator over the values of a tensor.
*
* @param t The tensor to iterate over
*
* @return The iterator
*/
assert(!tensor_is_empty(t)); assert(!tensor_is_empty(t));
tensoriter_scalar it = malloc(sizeof(struct _tensor_scalar_iterator)); tensoriter_scalar it = malloc(sizeof(struct _tensor_scalar_iterator));
@@ -15,29 +21,53 @@ tensoriter_scalar tensoriter_scalar_create(tensor t)
void tensoriter_scalar_destroy(tensoriter_scalar it) void tensoriter_scalar_destroy(tensoriter_scalar it)
{ {
/* Destroys an iterator.
*
* @param it The iterator to destroy
*/
free(it); free(it);
} }
uint8_t tensoriter_scalar_next(tensoriter_scalar it) bool tensoriter_scalar_next(tensoriter_scalar it)
{ {
/* Checks whether the given iterator has a next value and sets this value
* as the current value if available. If there is not next value the
* iterator is destroyed and false is returned.
*
* @param it The iterator to evaluate
*
* @return true if there is a next value, false otherwise
*/
if (it->length == 1) { if (it->length == 1) {
tensoriter_scalar_destroy(it); tensoriter_scalar_destroy(it);
return 0; return false;
} }
it->current++; it->current++;
it->length--; it->length--;
return 1; return true;
} }
dtype *tensoriter_scalar_get(tensoriter_scalar it) dtype *tensoriter_scalar_get(tensoriter_scalar it)
{ {
/* Gets the current value of the iterator.
*
* @param it The iterator to operate on
*
* @return A pointer to the current value
*/
return it->current; return it->current;
} }
void tensoriter_scalar_map(tensoriter_scalar it, dtype (*func)(dtype)) void tensoriter_scalar_map(tensoriter_scalar it, dtype (*func)(dtype))
{ {
/* Replaces every value in an iterator with the result of given function
* with the old value as a parameter.
*
* @param it The iterator to operate on
* @param func The map function that is called
*/
do { do {
dtype *el = tensoriter_scalar_get(it); dtype *el = tensoriter_scalar_get(it);
*el = func(*el); *el = func(*el);
@@ -46,6 +76,11 @@ void tensoriter_scalar_map(tensoriter_scalar it, dtype (*func)(dtype))
void tensoriter_scalar_map_add(tensoriter_scalar it, dtype scalar) void tensoriter_scalar_map_add(tensoriter_scalar it, dtype scalar)
{ {
/* Adds a fixed scalar value to all the values of the iterator.
*
* @param it The iterator to operate on
* @param scalar The value to add
*/
do { do {
DTYPE_ADD(*tensoriter_scalar_get(it), scalar); DTYPE_ADD(*tensoriter_scalar_get(it), scalar);
} while(tensoriter_scalar_next(it)); } while(tensoriter_scalar_next(it));
@@ -53,6 +88,11 @@ void tensoriter_scalar_map_add(tensoriter_scalar it, dtype scalar)
void tensoriter_scalar_map_sub(tensoriter_scalar it, dtype scalar) void tensoriter_scalar_map_sub(tensoriter_scalar it, dtype scalar)
{ {
/* Subtracts a fixed scalar value from all the values of the iterator.
*
* @param it The iterator to operate on
* @param scalar The value to subtract
*/
do { do {
DTYPE_SUB(*tensoriter_scalar_get(it), scalar); DTYPE_SUB(*tensoriter_scalar_get(it), scalar);
} while(tensoriter_scalar_next(it)); } while(tensoriter_scalar_next(it));
@@ -60,6 +100,11 @@ void tensoriter_scalar_map_sub(tensoriter_scalar it, dtype scalar)
void tensoriter_scalar_map_mul(tensoriter_scalar it, dtype scalar) void tensoriter_scalar_map_mul(tensoriter_scalar it, dtype scalar)
{ {
/* Multiplies a fixed scalar value with all the values of the iterator.
*
* @param it The iterator to operate on
* @param scalar The value to multiply
*/
do { do {
DTYPE_MUL(*tensoriter_scalar_get(it), scalar); DTYPE_MUL(*tensoriter_scalar_get(it), scalar);
} while(tensoriter_scalar_next(it)); } while(tensoriter_scalar_next(it));
@@ -67,6 +112,11 @@ void tensoriter_scalar_map_mul(tensoriter_scalar it, dtype scalar)
void tensoriter_scalar_map_div(tensoriter_scalar it, dtype scalar) void tensoriter_scalar_map_div(tensoriter_scalar it, dtype scalar)
{ {
/* Divides all the values of the iterator by a fixed scalar value.
*
* @param it The iterator to operate on
* @param scalar The value to divide by
*/
do { do {
DTYPE_DIV(*tensoriter_scalar_get(it), scalar); DTYPE_DIV(*tensoriter_scalar_get(it), scalar);
} while(tensoriter_scalar_next(it)); } while(tensoriter_scalar_next(it));

View File

@@ -11,7 +11,7 @@ typedef struct _tensor_scalar_iterator {
tensoriter_scalar tensoriter_scalar_create(tensor t); tensoriter_scalar tensoriter_scalar_create(tensor t);
void tensoriter_scalar_destroy(tensoriter_scalar it); void tensoriter_scalar_destroy(tensoriter_scalar it);
uint8_t tensoriter_scalar_next(tensoriter_scalar it); bool tensoriter_scalar_next(tensoriter_scalar it);
dtype *tensoriter_scalar_get(tensoriter_scalar it); dtype *tensoriter_scalar_get(tensoriter_scalar it);
void tensoriter_scalar_map(tensoriter_scalar it, dtype (*func)(dtype)); void tensoriter_scalar_map(tensoriter_scalar it, dtype (*func)(dtype));