tensor_is_equal and tensor_cpy added.

This commit is contained in:
2023-03-12 20:29:55 +01:00
parent 7d22af5e85
commit 35b866f1c9
2 changed files with 27 additions and 1 deletions

View File

@@ -15,10 +15,24 @@ void tensor_destroy(tensor t)
} }
} }
int tensor_is_empty(const tensor t){ int tensor_is_empty(const tensor t)
{
return t->elements == NULL || t->size == NULL; return t->elements == NULL || t->size == NULL;
} }
int tensor_is_equal(const tensor t1, const tensor t2)
{
int i;
if (t1->dimension != t2->dimension) return 0;
for (i = 0; i < t1->dimension; i++) {
if (t1->size[i] != t2->size[i]) return 0;
}
for (i = 0; i < t1->num_elem; i++) {
if (t1->elements[i] != t2->elements[i]) return 0;
}
return 1;
}
int _tensor_check_size(const int *size, int dim) int _tensor_check_size(const int *size, int dim)
{ {
int i; int i;
@@ -165,6 +179,16 @@ void tensor_for_each_elem(tensor t, dtype (*func)(dtype))
} }
} }
int tensor_cpy(tensor t1, const tensor t2)
{
int i;
if(!_tensor_set_size(t1, t2->size, t2->dimension)) return 0;
for(i = 0; i < t2->num_elem; i++) {
t1->elements[i] = t2->elements[i];
}
return 1;
}
void tensor_print(const tensor t) void tensor_print(const tensor t)
{ {
int i, j; int i, j;

View File

@@ -22,6 +22,7 @@ tensor tensor_new(void);
void tensor_destroy(tensor t); void tensor_destroy(tensor t);
int tensor_is_empty(const tensor t); int tensor_is_empty(const tensor t);
int tensor_is_equal(const tensor t1, const tensor t2);
int _tensor_check_size(const int *size, int dim); int _tensor_check_size(const int *size, int dim);
int _tensor_set_size(tensor t, const int *size, int dim); int _tensor_set_size(tensor t, const int *size, int dim);
@@ -36,6 +37,7 @@ int tensor_init_rand(tensor t, int dimension, const int *size, int max);
int tensor_add(tensor t1, const tensor t2); int tensor_add(tensor t1, const tensor t2);
void tensor_for_each_elem(tensor t, dtype (*func)(dtype)); void tensor_for_each_elem(tensor t, dtype (*func)(dtype));
int tensor_cpy(tensor t1, const tensor t2);
void tensor_print(const tensor t); void tensor_print(const tensor t);
#endif #endif