adapts dataset for test set
This commit is contained in:
@@ -40,10 +40,13 @@ class MMP_Dataset(torch.utils.data.Dataset):
|
|||||||
match = img_pattern.match(fname)
|
match = img_pattern.match(fname)
|
||||||
if match:
|
if match:
|
||||||
img_file = os.path.join(path_to_data, fname)
|
img_file = os.path.join(path_to_data, fname)
|
||||||
annotations = read_groundtruth_file(
|
if is_test:
|
||||||
os.path.join(path_to_data, f"{match.group(1)}.gt_data.txt")
|
self.images.append((img_file, None))
|
||||||
)
|
else:
|
||||||
self.images.append((img_file, annotations))
|
annotations = read_groundtruth_file(
|
||||||
|
os.path.join(path_to_data, f"{match.group(1)}.gt_data.txt")
|
||||||
|
)
|
||||||
|
self.images.append((img_file, annotations))
|
||||||
|
|
||||||
self.images.sort(
|
self.images.sort(
|
||||||
key=lambda x: int(re.match(r"(.*/)(\d+)(\.jpg)", x[0]).group(2))
|
key=lambda x: int(re.match(r"(.*/)(\d+)(\.jpg)", x[0]).group(2))
|
||||||
@@ -66,10 +69,12 @@ class MMP_Dataset(torch.utils.data.Dataset):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
img_tensor = transform(img)
|
img_tensor = transform(img)
|
||||||
|
img_id = re.match(r".*(\/)([0-9]+)(\.[^\/]*$)", self.images[idx][0]).group(2)
|
||||||
|
if self.is_test:
|
||||||
|
return (img_tensor, torch.Tensor(), int(img_id))
|
||||||
label_grid = get_label_grid(
|
label_grid = get_label_grid(
|
||||||
anchor_grid=self.anchor_grid, gts=self.images[idx][1], min_iou=self.min_iou
|
anchor_grid=self.anchor_grid, gts=self.images[idx][1], min_iou=self.min_iou
|
||||||
)
|
)
|
||||||
img_id = re.match(r".*(\/)([0-9]+)(\.[^\/]*$)", self.images[idx][0]).group(2)
|
|
||||||
return (img_tensor, label_grid, int(img_id))
|
return (img_tensor, label_grid, int(img_id))
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
|
|||||||
Reference in New Issue
Block a user