diff --git a/tensor.c b/tensor.c index 859ab4e..8ae0e48 100644 --- a/tensor.c +++ b/tensor.c @@ -15,25 +15,25 @@ void tensor_destroy(tensor t) free(t); } -uint8_t tensor_is_empty(const tensor t) +bool tensor_is_empty(const tensor t) { return t->elements == NULL || t->size == NULL; } -uint8_t tensor_is_equal(const tensor t1, const tensor t2) +bool tensor_is_equal(const tensor t1, const tensor t2) { assert(!tensor_is_empty(t1)); assert(!tensor_is_empty(t2)); uint32_t i; - if (t1->rank != t2->rank) return 0; + if (t1->rank != t2->rank) return false; for (i = 0; i < t1->rank; i++) { - if (t1->size[i] != t2->size[i]) return 0; + if (t1->size[i] != t2->size[i]) return false; } for (i = 0; i < t1->num_elem; i++) { - if (DTYPE_NE(t1->elements[i], t2->elements[i])) return 0; + if (DTYPE_NE(t1->elements[i], t2->elements[i])) return false; } - return 1; + return true; } uint8_t _tensor_check_size(const uint32_t *size, uint8_t rank) diff --git a/tensor.h b/tensor.h index 841fb53..6331a54 100644 --- a/tensor.h +++ b/tensor.h @@ -6,8 +6,8 @@ #include #include #include -#include #include +#include /* Defining the datatype of the tensor */ typedef float dtype; @@ -39,8 +39,8 @@ typedef struct _tensor { tensor tensor_new(void); void tensor_destroy(tensor t); -uint8_t tensor_is_empty(const tensor t); -uint8_t tensor_is_equal(const tensor t1, const tensor t2); +bool tensor_is_empty(const tensor t); +bool tensor_is_equal(const tensor t1, const tensor t2); uint8_t _tensor_check_size(const uint32_t *size, uint8_t rank); uint8_t _tensor_set_size(tensor t, const uint32_t *size, uint8_t rank);