Added tensor subtraction.

This commit is contained in:
2023-03-18 20:56:40 +01:00
parent 52d240fa98
commit 0ef0eeb047
2 changed files with 17 additions and 0 deletions

View File

@@ -238,6 +238,22 @@ int tensor_add(tensor t1, const tensor t2)
return 1; return 1;
} }
int tensor_sub(tensor t1, const tensor t2)
{
assert(!tensor_is_empty(t1));
assert(!tensor_is_empty(t2));
int i;
if(t1->dimension != t2->dimension) return 0;
for(i = 0; i < t1->dimension; i++) {
if(t1->size[i] != t2->size[i]) return 0;
}
for(i = 0; i < t1->num_elem; i++) {
t1->elements[i] -= t2->elements[i];
}
return 1;
}
void tensor_for_each_elem(tensor t, dtype (*func)(dtype)) void tensor_for_each_elem(tensor t, dtype (*func)(dtype))
{ {
assert(!tensor_is_empty(t)); assert(!tensor_is_empty(t));

View File

@@ -42,6 +42,7 @@ void tensor_sub_scalar(tensor t, dtype n);
void tensor_mult_scalar(tensor t, dtype n); void tensor_mult_scalar(tensor t, dtype n);
void tensor_div_scalar(tensor t, dtype n); void tensor_div_scalar(tensor t, dtype n);
int tensor_add(tensor t1, const tensor t2); int tensor_add(tensor t1, const tensor t2);
int tensor_sub(tensor t1, const tensor t2);
void tensor_for_each_elem(tensor t, dtype (*func)(dtype)); void tensor_for_each_elem(tensor t, dtype (*func)(dtype));
void tensor_print(const tensor t); void tensor_print(const tensor t);