Minor bug fixes.

This commit is contained in:
2023-03-09 18:24:17 +01:00
parent dbf11ef652
commit 7d22af5e85

View File

@@ -40,10 +40,10 @@ int _tensor_set_size(tensor t, const int *size, int dim)
} }
if(!_tensor_check_size(size, dim)) return 0; if(!_tensor_check_size(size, dim)) return 0;
/* Try allocating memory for the size array of the Tensor */ /* Try allocating memory for the size array of the tensor */
temp = realloc(t->size, dim * sizeof(int)); temp = realloc(t->size, dim * sizeof(int));
if(temp == NULL && dim != 0) return 0; if(temp == NULL && dim != 0) return 0;
/* Try allocating memory for the Tensor */ /* Try allocating memory for the tensor */
t_temp = realloc(t->elements, num_elem * sizeof(dtype)); t_temp = realloc(t->elements, num_elem * sizeof(dtype));
if(t_temp == NULL) { if(t_temp == NULL) {
/* Revert to before the function call and return */ /* Revert to before the function call and return */
@@ -70,7 +70,10 @@ int tensor_set(tensor t, const int *index, dtype val)
{ {
int i, offset = 0; int i, offset = 0;
if(tensor_is_empty(t)) return 0; if(tensor_is_empty(t)) return 0;
if(t->dimension == 0) return t->elements[0] = val; if(t->dimension == 0) {
t->elements[0] = val;
return 1;
}
for(i = 0; i < t->dimension - 1; i++) { for(i = 0; i < t->dimension - 1; i++) {
if(t->size[i] <= index[i]) return 0; if(t->size[i] <= index[i]) return 0;
@@ -181,11 +184,11 @@ void tensor_print(const tensor t)
if(t->dimension == 0) { if(t->dimension == 0) {
/* Skalar */ /* scalar */
printf(PRINT_STRING, t->elements[0]); printf(PRINT_STRING, t->elements[0]);
putchar('\n'); putchar('\n');
} else if (t->dimension == 1) { } else if (t->dimension == 1) {
/* Spaltenvektor */ /* column vector */
if(t->size[0] == 1) { if(t->size[0] == 1) {
putchar('('); putchar('(');
printf(PRINT_STRING, t->elements[0]); printf(PRINT_STRING, t->elements[0]);
@@ -204,7 +207,7 @@ void tensor_print(const tensor t)
printf("/\n"); printf("/\n");
} }
} else if (t->dimension == 2) { } else if (t->dimension == 2) {
/* Matix */ /* matix */
indx = malloc(sizeof(int) * 2); indx = malloc(sizeof(int) * 2);
if(t->size[0] == 1) { if(t->size[0] == 1) {
putchar('('); putchar('(');