formatting

This commit is contained in:
franksim
2025-11-07 11:20:08 +01:00
parent 8fc3559d6c
commit b159d76517
8 changed files with 119 additions and 82 deletions

View File

@@ -11,7 +11,7 @@ def pad_to_square(img):
max_wh = max(w, h)
pad = ((max_wh - w) // 2, (max_wh - h) // 2)
padding = (pad[0], pad[1], max_wh - w - pad[0], max_wh - h - pad[1])
return F.pad(img, padding, fill=0, padding_mode='constant')
return F.pad(img, padding, fill=0, padding_mode="constant")
def build_batch(paths: Sequence[str], transform=None) -> torch.Tensor:
@@ -21,17 +21,18 @@ def build_batch(paths: Sequence[str], transform=None) -> torch.Tensor:
@param transform: One or multiple image transformations for augmenting the batch images.
@return: Returns one single tensor that contains every image.
"""
preprocess = transforms.Compose([
transforms.Lambda(pad_to_square),
transforms.Resize((224, 224)),
*([transform] if transform is not None else []),
transforms.ToTensor()
]
preprocess = transforms.Compose(
[
transforms.Lambda(pad_to_square),
transforms.Resize((224, 224)),
*([transform] if transform is not None else []),
transforms.ToTensor(),
]
)
imgs = []
for path in paths:
img = Image.open(path).convert('RGB')
img = Image.open(path).convert("RGB")
img = preprocess(img)
imgs.append(img)
batch = torch.stack(imgs)
@@ -43,8 +44,7 @@ def get_model() -> torch.nn.Module:
@return: Returns a neural network, initialised with pretrained weights.
"""
model = models.resnet18(
weights=models.ResNet18_Weights.DEFAULT)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
return model