formatting

This commit is contained in:
franksim
2025-11-07 11:20:08 +01:00
parent 8fc3559d6c
commit b159d76517
8 changed files with 119 additions and 82 deletions

View File

@@ -11,7 +11,7 @@ def pad_to_square(img):
max_wh = max(w, h)
pad = ((max_wh - w) // 2, (max_wh - h) // 2)
padding = (pad[0], pad[1], max_wh - w - pad[0], max_wh - h - pad[1])
return F.pad(img, padding, fill=0, padding_mode='constant')
return F.pad(img, padding, fill=0, padding_mode="constant")
def build_batch(paths: Sequence[str], transform=None) -> torch.Tensor:
@@ -21,17 +21,18 @@ def build_batch(paths: Sequence[str], transform=None) -> torch.Tensor:
@param transform: One or multiple image transformations for augmenting the batch images.
@return: Returns one single tensor that contains every image.
"""
preprocess = transforms.Compose([
transforms.Lambda(pad_to_square),
transforms.Resize((224, 224)),
*([transform] if transform is not None else []),
transforms.ToTensor()
]
preprocess = transforms.Compose(
[
transforms.Lambda(pad_to_square),
transforms.Resize((224, 224)),
*([transform] if transform is not None else []),
transforms.ToTensor(),
]
)
imgs = []
for path in paths:
img = Image.open(path).convert('RGB')
img = Image.open(path).convert("RGB")
img = preprocess(img)
imgs.append(img)
batch = torch.stack(imgs)
@@ -43,8 +44,7 @@ def get_model() -> torch.nn.Module:
@return: Returns a neural network, initialised with pretrained weights.
"""
model = models.resnet18(
weights=models.ResNet18_Weights.DEFAULT)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
return model

View File

@@ -5,7 +5,12 @@ def avg_color(img: torch.Tensor):
return img.mean(dim=(1, 2))
def mask(foreground: torch.Tensor, background: torch.Tensor, mask_tensor: torch.Tensor, threshold: float):
def mask(
foreground: torch.Tensor,
background: torch.Tensor,
mask_tensor: torch.Tensor,
threshold: float,
):
mask = mask_tensor > threshold
if foreground.dim() == 3:
mask = mask.unsqueeze(0)