Added iterators over scalars of tensors

This commit is contained in:
2023-05-07 16:58:49 +02:00
parent 78623d2e9d
commit 19c644f450
4 changed files with 112 additions and 28 deletions

View File

@@ -1,8 +1,5 @@
#include "tensor.h" #include "tensor.h"
#include "tensoriterator.h"
/* A helper for f.e. the tensor_add_scalar function */
dtype _tensor_scalar_wise_helper;
tensor tensor_new(void) tensor tensor_new(void)
{ {
@@ -189,41 +186,40 @@ uint8_t tensor_cpy(tensor t1, const tensor t2)
return 1; return 1;
} }
dtype _tensor_dtype_add_helper(dtype d) { return DTYPE_ADD(d, _tensor_scalar_wise_helper); }
dtype _tensor_dtype_sub_helper(dtype d) { return DTYPE_SUB(d, _tensor_scalar_wise_helper); }
dtype _tensor_dtype_mul_helper(dtype d) { return DTYPE_MUL(d, _tensor_scalar_wise_helper); }
dtype _tensor_dtype_div_helper(dtype d) { return DTYPE_DIV(d, _tensor_scalar_wise_helper); }
void tensor_add_scalar(tensor t, dtype n) void tensor_add_scalar(tensor t, dtype n)
{ {
assert(!tensor_is_empty(t)); assert(!tensor_is_empty(t));
_tensor_scalar_wise_helper = n; tensoriter_scalar iter = tensoriter_scalar_create(t);
tensor_map(t, &_tensor_dtype_add_helper); tensoriter_scalar_map_add(iter, n);
tensoriter_scalar_destroy(iter);
} }
void tensor_sub_scalar(tensor t, dtype n) void tensor_sub_scalar(tensor t, dtype n)
{ {
assert(!tensor_is_empty(t)); assert(!tensor_is_empty(t));
_tensor_scalar_wise_helper = n; tensoriter_scalar iter = tensoriter_scalar_create(t);
tensor_map(t, &_tensor_dtype_sub_helper); tensoriter_scalar_map_sub(iter, n);
tensoriter_scalar_destroy(iter);
} }
void tensor_mul_scalar(tensor t, dtype n) void tensor_mul_scalar(tensor t, dtype n)
{ {
assert(!tensor_is_empty(t)); assert(!tensor_is_empty(t));
_tensor_scalar_wise_helper = n; tensoriter_scalar iter = tensoriter_scalar_create(t);
tensor_map(t, &_tensor_dtype_mul_helper); tensoriter_scalar_map_mul(iter, n);
tensoriter_scalar_destroy(iter);
} }
void tensor_div_scalar(tensor t, dtype n) void tensor_div_scalar(tensor t, dtype n)
{ {
assert(!tensor_is_empty(t)); assert(!tensor_is_empty(t));
_tensor_scalar_wise_helper = n; tensoriter_scalar iter = tensoriter_scalar_create(t);
tensor_map(t, &_tensor_dtype_div_helper); tensoriter_scalar_map_div(iter, n);
tensoriter_scalar_destroy(iter);
} }
uint8_t tensor_add_inplace(tensor t1, const tensor t2) uint8_t tensor_add_inplace(tensor t1, const tensor t2)
@@ -283,16 +279,6 @@ tensor tensor_sub(const tensor t1, const tensor t2)
} }
void tensor_map(tensor t, dtype (*func)(dtype))
{
assert(!tensor_is_empty(t));
uint32_t i;
for(i = 0; i < t->num_elem; i++) {
t->elements[i] = func(t->elements[i]);
}
}
void tensor_print(const tensor t) void tensor_print(const tensor t)
{ {
uint32_t i, j; uint32_t i, j;

View File

@@ -64,7 +64,6 @@ uint8_t tensor_sub_inplace(tensor t1, const tensor t2);
tensor tensor_add(const tensor t1, const tensor t2); tensor tensor_add(const tensor t1, const tensor t2);
tensor tensor_sub(const tensor t1, const tensor t2); tensor tensor_sub(const tensor t1, const tensor t2);
void tensor_map(tensor t, dtype (*func)(dtype));
void tensor_print(const tensor t); void tensor_print(const tensor t);
#endif #endif

74
tensoriterator.c Normal file
View File

@@ -0,0 +1,74 @@
#include "tensoriterator.h"
tensoriter_scalar tensoriter_scalar_create(tensor t)
{
assert(!tensor_is_empty(t));
tensoriter_scalar it = malloc(sizeof(struct _tensor_scalar_iterator));
if (it == NULL) return NULL;
it->current = t->elements;
it->length = t->num_elem;
return it;
}
void tensoriter_scalar_destroy(tensoriter_scalar it)
{
free(it);
}
uint8_t tensoriter_scalar_next(tensoriter_scalar it)
{
if (it->length == 1) {
tensoriter_scalar_destroy(it);
return 0;
}
it->current++;
it->length--;
return 1;
}
dtype *tensoriter_scalar_get(tensoriter_scalar it)
{
return it->current;
}
void tensoriter_scalar_map(tensoriter_scalar it, dtype (*func)(dtype))
{
do {
dtype el = *tensoriter_scalar_get(it);
el = func(el);
} while(tensoriter_scalar_next(it));
}
void tensoriter_scalar_map_add(tensoriter_scalar it, dtype scalar)
{
do {
DTYPE_ADD(*tensoriter_scalar_get(it), scalar);
} while(tensoriter_scalar_next(it));
}
void tensoriter_scalar_map_sub(tensoriter_scalar it, dtype scalar)
{
do {
DTYPE_SUB(*tensoriter_scalar_get(it), scalar);
} while(tensoriter_scalar_next(it));
}
void tensoriter_scalar_map_mul(tensoriter_scalar it, dtype scalar)
{
do {
DTYPE_MUL(*tensoriter_scalar_get(it), scalar);
} while(tensoriter_scalar_next(it));
}
void tensoriter_scalar_map_div(tensoriter_scalar it, dtype scalar)
{
do {
DTYPE_DIV(*tensoriter_scalar_get(it), scalar);
} while(tensoriter_scalar_next(it));
}

25
tensoriterator.h Normal file
View File

@@ -0,0 +1,25 @@
#ifndef _TENSORITERATOR_H_
#define _TENSORITERATOR_H_
#include "tensor.h"
typedef struct _tensor_scalar_iterator {
dtype *current;
uint32_t length;
} * tensoriter_scalar;
tensoriter_scalar tensoriter_scalar_create(tensor t);
void tensoriter_scalar_reset(tensoriter_scalar it);
void tensoriter_scalar_destroy(tensoriter_scalar it);
uint8_t tensoriter_scalar_next(tensoriter_scalar it);
dtype *tensoriter_scalar_get(tensoriter_scalar it);
void tensoriter_scalar_map(tensoriter_scalar it, dtype (*func)(dtype));
void tensoriter_scalar_map_add(tensoriter_scalar it, dtype scalar);
void tensoriter_scalar_map_sub(tensoriter_scalar it, dtype scalar);
void tensoriter_scalar_map_mul(tensoriter_scalar it, dtype scalar);
void tensoriter_scalar_map_div(tensoriter_scalar it, dtype scalar);
#endif // _TENSORITERATOR_H_