Fixed major bug when indexing elements.

This commit is contained in:
2023-03-12 21:29:10 +01:00
parent 35b866f1c9
commit 0b7ba4fa52

View File

@@ -82,41 +82,58 @@ int _tensor_set_size(tensor t, const int *size, int dim)
int tensor_set(tensor t, const int *index, dtype val) int tensor_set(tensor t, const int *index, dtype val)
{ {
int i, offset = 0; int i, j, offset = 0;
int *size_offset = malloc(t->dimension * sizeof(int));
/* TODO free on error */
if(size_offset == NULL) return 0;
for(i = 0; i < t->dimension; i++) {
size_offset[i] = 1;
for(j = i + 1; j < t->dimension; j++) {
size_offset[i] *= t->size[j];
}
}
if(tensor_is_empty(t)) return 0; if(tensor_is_empty(t)) return 0;
if(t->dimension == 0) { if(t->dimension == 0) {
t->elements[0] = val; t->elements[0] = val;
return 1; return 1;
} }
for(i = 0; i < t->dimension - 1; i++) { for(i = 0; i < t->dimension; i++) {
if(t->size[i] <= index[i]) return 0; if(t->size[i] <= index[i]) return 0;
offset += t->size[i + 1] * index[i]; offset += size_offset[i] * index[i];
} }
if(t->size[t->dimension - 1] <= index[t->dimension - 1]) return 0;
offset += index[t->dimension - 1];
t->elements[offset] = val; t->elements[offset] = val;
free(size_offset);
return 1; return 1;
} }
dtype tensor_get(const tensor t, const int *index, int *success) dtype tensor_get(const tensor t, const int *index, int *success)
{ {
int i, offset = 0; int i, j, offset = 0;
int *size_offset = malloc(t->dimension * sizeof(int));
/* TODO free on error */
if(size_offset == NULL) return 0;
for(i = 0; i < t->dimension; i++) {
size_offset[i] = 1;
for(j = i + 1; j < t->dimension; j++) {
size_offset[i] *= t->size[j];
}
}
if(tensor_is_empty(t)) return 0; if(tensor_is_empty(t)) return 0;
if(t->dimension == 0) return t->elements[0]; if(t->dimension == 0) return t->elements[0];
for(i = 0; i < t->dimension - 1; i++) { for(i = 0; i < t->dimension; i++) {
if(t->size[i] <= index[i]) { if(t->size[i] <= index[i]) {
if(success != NULL) *success = 0; if(success != NULL) *success = 0;
return 0; return 0;
} }
offset += t->size[i + 1] * index[i]; offset += size_offset[i] * index[i];
} }
if(t->size[t->dimension - 1] <= index[t->dimension - 1]) {
if(success != NULL) *success = 0;
return 0;
}
offset += index[t->dimension - 1];
if(success != NULL) *success = 1; if(success != NULL) *success = 1;
return t->elements[offset]; return t->elements[offset];