Changed implementation of scalar operations to use map.

This commit is contained in:
2023-04-11 20:39:14 +02:00
parent f7b8589b87
commit 721ac75389
2 changed files with 23 additions and 18 deletions

View File

@@ -1,6 +1,10 @@
#include "tensor.h"
#include <stdio.h>
/* A helper for f.e. the tensor_add_scalar function */
dtype _tensor_scalar_wise_helper;
tensor tensor_new(void)
{
return calloc(1, sizeof(struct _tensor));
@@ -183,44 +187,42 @@ int tensor_cpy(tensor t1, const tensor t2)
return 1;
}
dtype _tensor_dtype_add_helper(dtype d) { return DTYPE_ADD(d, _tensor_scalar_wise_helper); }
dtype _tensor_dtype_sub_helper(dtype d) { return DTYPE_SUB(d, _tensor_scalar_wise_helper); }
dtype _tensor_dtype_mul_helper(dtype d) { return DTYPE_MUL(d, _tensor_scalar_wise_helper); }
dtype _tensor_dtype_div_helper(dtype d) { return DTYPE_DIV(d, _tensor_scalar_wise_helper); }
void tensor_add_scalar(tensor t, dtype n)
{
assert(!tensor_is_empty(t));
int i;
for(i = 0; i < t->num_elem; i++) {
t->elements[i] = DTYPE_ADD(t->elements[i], n);
}
_tensor_scalar_wise_helper = n;
tensor_map(t, &_tensor_dtype_add_helper);
}
void tensor_sub_scalar(tensor t, dtype n)
{
assert(!tensor_is_empty(t));
int i;
for(i = 0; i < t->num_elem; i++) {
t->elements[i] = DTYPE_SUB(t->elements[i], n);
}
_tensor_scalar_wise_helper = n;
tensor_map(t, &_tensor_dtype_sub_helper);
}
void tensor_mult_scalar(tensor t, dtype n)
void tensor_mul_scalar(tensor t, dtype n)
{
assert(!tensor_is_empty(t));
int i;
for(i = 0; i < t->num_elem; i++) {
t->elements[i] = DTYPE_MUL(t->elements[i], n);
}
_tensor_scalar_wise_helper = n;
tensor_map(t, &_tensor_dtype_mul_helper);
}
void tensor_div_scalar(tensor t, dtype n)
{
assert(!tensor_is_empty(t));
int i;
for(i = 0; i < t->num_elem; i++) {
t->elements[i] = DTYPE_DIV(t->elements[i], n);
}
_tensor_scalar_wise_helper = n;
tensor_map(t, &_tensor_dtype_div_helper);
}
int tensor_add_inplace(tensor t1, const tensor t2)