adapts dataset for test set
This commit is contained in:
@@ -40,10 +40,13 @@ class MMP_Dataset(torch.utils.data.Dataset):
|
||||
match = img_pattern.match(fname)
|
||||
if match:
|
||||
img_file = os.path.join(path_to_data, fname)
|
||||
annotations = read_groundtruth_file(
|
||||
os.path.join(path_to_data, f"{match.group(1)}.gt_data.txt")
|
||||
)
|
||||
self.images.append((img_file, annotations))
|
||||
if is_test:
|
||||
self.images.append((img_file, None))
|
||||
else:
|
||||
annotations = read_groundtruth_file(
|
||||
os.path.join(path_to_data, f"{match.group(1)}.gt_data.txt")
|
||||
)
|
||||
self.images.append((img_file, annotations))
|
||||
|
||||
self.images.sort(
|
||||
key=lambda x: int(re.match(r"(.*/)(\d+)(\.jpg)", x[0]).group(2))
|
||||
@@ -66,10 +69,12 @@ class MMP_Dataset(torch.utils.data.Dataset):
|
||||
]
|
||||
)
|
||||
img_tensor = transform(img)
|
||||
img_id = re.match(r".*(\/)([0-9]+)(\.[^\/]*$)", self.images[idx][0]).group(2)
|
||||
if self.is_test:
|
||||
return (img_tensor, torch.Tensor(), int(img_id))
|
||||
label_grid = get_label_grid(
|
||||
anchor_grid=self.anchor_grid, gts=self.images[idx][1], min_iou=self.min_iou
|
||||
)
|
||||
img_id = re.match(r".*(\/)([0-9]+)(\.[^\/]*$)", self.images[idx][0]).group(2)
|
||||
return (img_tensor, label_grid, int(img_id))
|
||||
|
||||
def __len__(self) -> int:
|
||||
|
||||
Reference in New Issue
Block a user