formatting
This commit is contained in:
@@ -11,7 +11,7 @@ def pad_to_square(img):
|
||||
max_wh = max(w, h)
|
||||
pad = ((max_wh - w) // 2, (max_wh - h) // 2)
|
||||
padding = (pad[0], pad[1], max_wh - w - pad[0], max_wh - h - pad[1])
|
||||
return F.pad(img, padding, fill=0, padding_mode='constant')
|
||||
return F.pad(img, padding, fill=0, padding_mode="constant")
|
||||
|
||||
|
||||
def build_batch(paths: Sequence[str], transform=None) -> torch.Tensor:
|
||||
@@ -21,17 +21,18 @@ def build_batch(paths: Sequence[str], transform=None) -> torch.Tensor:
|
||||
@param transform: One or multiple image transformations for augmenting the batch images.
|
||||
@return: Returns one single tensor that contains every image.
|
||||
"""
|
||||
preprocess = transforms.Compose([
|
||||
transforms.Lambda(pad_to_square),
|
||||
transforms.Resize((224, 224)),
|
||||
*([transform] if transform is not None else []),
|
||||
transforms.ToTensor()
|
||||
]
|
||||
preprocess = transforms.Compose(
|
||||
[
|
||||
transforms.Lambda(pad_to_square),
|
||||
transforms.Resize((224, 224)),
|
||||
*([transform] if transform is not None else []),
|
||||
transforms.ToTensor(),
|
||||
]
|
||||
)
|
||||
imgs = []
|
||||
|
||||
for path in paths:
|
||||
img = Image.open(path).convert('RGB')
|
||||
img = Image.open(path).convert("RGB")
|
||||
img = preprocess(img)
|
||||
imgs.append(img)
|
||||
batch = torch.stack(imgs)
|
||||
@@ -43,8 +44,7 @@ def get_model() -> torch.nn.Module:
|
||||
|
||||
@return: Returns a neural network, initialised with pretrained weights.
|
||||
"""
|
||||
model = models.resnet18(
|
||||
weights=models.ResNet18_Weights.DEFAULT)
|
||||
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user