From 85fd327052c477486cce793797a02f1bdc72943f Mon Sep 17 00:00:00 2001 From: franksim Date: Sun, 9 Nov 2025 12:33:02 +0100 Subject: [PATCH] adapts dataset for test set --- mmp/a4/dataset.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/mmp/a4/dataset.py b/mmp/a4/dataset.py index 5d466ca..e591f2e 100644 --- a/mmp/a4/dataset.py +++ b/mmp/a4/dataset.py @@ -40,10 +40,13 @@ class MMP_Dataset(torch.utils.data.Dataset): match = img_pattern.match(fname) if match: img_file = os.path.join(path_to_data, fname) - annotations = read_groundtruth_file( - os.path.join(path_to_data, f"{match.group(1)}.gt_data.txt") - ) - self.images.append((img_file, annotations)) + if is_test: + self.images.append((img_file, None)) + else: + annotations = read_groundtruth_file( + os.path.join(path_to_data, f"{match.group(1)}.gt_data.txt") + ) + self.images.append((img_file, annotations)) self.images.sort( key=lambda x: int(re.match(r"(.*/)(\d+)(\.jpg)", x[0]).group(2)) @@ -66,10 +69,12 @@ class MMP_Dataset(torch.utils.data.Dataset): ] ) img_tensor = transform(img) + img_id = re.match(r".*(\/)([0-9]+)(\.[^\/]*$)", self.images[idx][0]).group(2) + if self.is_test: + return (img_tensor, torch.Tensor(), int(img_id)) label_grid = get_label_grid( anchor_grid=self.anchor_grid, gts=self.images[idx][1], min_iou=self.min_iou ) - img_id = re.match(r".*(\/)([0-9]+)(\.[^\/]*$)", self.images[idx][0]).group(2) return (img_tensor, label_grid, int(img_id)) def __len__(self) -> int: