adapts dataset for test set

This commit is contained in:
franksim
2025-11-09 12:33:02 +01:00
parent 51b18f1d82
commit 85fd327052

View File

@@ -40,6 +40,9 @@ class MMP_Dataset(torch.utils.data.Dataset):
match = img_pattern.match(fname)
if match:
img_file = os.path.join(path_to_data, fname)
if is_test:
self.images.append((img_file, None))
else:
annotations = read_groundtruth_file(
os.path.join(path_to_data, f"{match.group(1)}.gt_data.txt")
)
@@ -66,10 +69,12 @@ class MMP_Dataset(torch.utils.data.Dataset):
]
)
img_tensor = transform(img)
img_id = re.match(r".*(\/)([0-9]+)(\.[^\/]*$)", self.images[idx][0]).group(2)
if self.is_test:
return (img_tensor, torch.Tensor(), int(img_id))
label_grid = get_label_grid(
anchor_grid=self.anchor_grid, gts=self.images[idx][1], min_iou=self.min_iou
)
img_id = re.match(r".*(\/)([0-9]+)(\.[^\/]*$)", self.images[idx][0]).group(2)
return (img_tensor, label_grid, int(img_id))
def __len__(self) -> int: