Added function for array comparison. Fixed mistake in Makefile.

This commit is contained in:
2023-09-07 19:39:35 +02:00
parent f942ef25c2
commit 6f320f593c
5 changed files with 62 additions and 30 deletions

View File

@@ -46,15 +46,9 @@ bool tensor_is_equal(const tensor t1, const tensor t2)
assert(!tensor_is_empty(t1));
assert(!tensor_is_empty(t2));
uint32_t i;
if (t1->rank != t2->rank) return false;
for (i = 0; i < t1->rank; i++) {
if (t1->size[i] != t2->size[i]) return false;
}
for (i = 0; i < t1->num_elem; i++) {
if (DTYPE_NE(t1->elements[i], t2->elements[i])) return false;
}
return true;
if (!tarray_uint32_equals(t1->size, t2->size, t1->rank)) return false;
return tarray_equals(t1->elements, t2->elements, t1->num_elem);
}
bool _tensor_check_size(const uint32_t *size, uint8_t rank)
@@ -287,9 +281,7 @@ bool tensor_add_inplace(tensor t1, const tensor t2)
uint32_t i;
if(t1->rank != t2->rank) return false;
for(i = 0; i < t1->rank; i++) {
if(t1->size[i] != t2->size[i]) return false;
}
if(!tarray_uint32_equals(t1->size, t2->size, t1->rank)) return false;
for(i = 0; i < t1->num_elem; i++) {
t1->elements[i] = DTYPE_ADD(t1->elements[i], t2->elements[i]);
}
@@ -312,9 +304,7 @@ bool tensor_sub_inplace(tensor t1, const tensor t2)
uint32_t i;
if(t1->rank != t2->rank) return false;
for(i = 0; i < t1->rank; i++) {
if(t1->size[i] != t2->size[i]) return false;
}
if(!tarray_uint32_equals(t1->size, t2->size, t1->rank)) return false;
for(i = 0; i < t1->num_elem; i++) {
t1->elements[i] = DTYPE_SUB(t1->elements[i], t2->elements[i]);
}
@@ -337,9 +327,7 @@ bool tensor_mul_inplace(tensor t1, const tensor t2)
uint32_t i;
if(t1->rank != t2->rank) return false;
for(i = 0; i < t1->rank; i++) {
if(t1->size[i] != t2->size[i]) return false;
}
if(!tarray_uint32_equals(t1->size, t2->size, t1->rank)) return false;
for(i = 0; i < t1->num_elem; i++) {
t1->elements[i] = DTYPE_MUL(t1->elements[i], t2->elements[i]);
}
@@ -362,9 +350,7 @@ bool tensor_div_inplace(tensor t1, const tensor t2)
uint32_t i;
if(t1->rank != t2->rank) return false;
for(i = 0; i < t1->rank; i++) {
if(t1->size[i] != t2->size[i]) return false;
}
if(!tarray_uint32_equals(t1->size, t2->size, t1->rank)) return false;
for(i = 0; i < t1->num_elem; i++) {
t1->elements[i] = DTYPE_DIV(t1->elements[i], t2->elements[i]);
}