# mmp_wise2526_franksim/mmp/a5/main.py
import argparse
import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

from .model import MmpNet
from ..a2.main import get_criterion_optimizer
from ..a4.anchor_grid import get_anchor_grid
from ..a4.dataset import get_dataloader


def step(
    model: MmpNet,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    img_batch: torch.Tensor,
    lbl_batch: torch.Tensor,
) -> float:
    """Run one forward/backward pass and optimizer update; return the batch loss."""
    model.train()
    optimizer.zero_grad()
    device = next(model.parameters()).device
    img_batch = img_batch.to(device)
    lbl_batch = lbl_batch.to(device)
    outputs = model(img_batch)
    loss = criterion(outputs, lbl_batch)
    loss.backward()
    optimizer.step()
    return loss.item()
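

# Note (illustrative): step() performs exactly one update on a single batch and
# returns the scalar batch loss; train() below wraps it in an epoch loop, so a
# minimal manual loop would just be
# `for imgs, lbls, _ in loader: step(model, criterion, optimizer, imgs, lbls)`.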


def get_random_sampling_mask(labels: torch.Tensor, neg_ratio: float) -> torch.Tensor:
    """
    @param labels: The label tensor that is returned by your data loader.
        The values are either 0 (negative label) or 1 (positive label).
    @param neg_ratio: The desired negative/positive ratio.
        Hint: after computing the mask, check if the neg_ratio is fulfilled.
    @return: A tensor with the same shape as labels.
    """
    # Flatten for easier indexing.
    labels_flat = labels.view(-1)
    pos_indices = (labels_flat == 1).nonzero(as_tuple=True)[0]
    neg_indices = (labels_flat == 0).nonzero(as_tuple=True)[0]
    num_pos = pos_indices.numel()
    num_neg = neg_indices.numel()
    # Keep at most neg_ratio negatives per positive, capped by the number of
    # available negatives.
    num_neg_to_sample = min(int(neg_ratio * num_pos), num_neg)
    perm = torch.randperm(num_neg, device=labels.device)
    sampled_neg_indices = neg_indices[perm[:num_neg_to_sample]]
    mask_flat = torch.zeros_like(labels_flat, dtype=torch.long)
    mask_flat[pos_indices] = 1
    mask_flat[sampled_neg_indices] = 1
    # Reshape to the original shape.
    mask = mask_flat.view_as(labels)
    return mask
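

# Illustrative example (hypothetical values): for
# labels = torch.tensor([1, 0, 0, 0, 0, 0]) and neg_ratio=3.0, the mask keeps
# the single positive plus 3 randomly chosen negatives, so mask.sum() == 4.
# With num_pos == 0 the mask stays all-zero, which the criterion below guards
# against.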


def get_detection_metrics(
    output: Tensor, labels: Tensor, threshold: float
) -> tuple[float, float, float, float]:
    """
    Returns precision, recall, and f1 for the positive (human) class, plus the
    overall accuracy.
    """
    with torch.no_grad():
        # Probability of the positive (human) class per anchor.
        probs = torch.softmax(output, dim=-1)[..., 1]
        preds = probs >= threshold
        TP = ((preds == 1) & (labels == 1)).sum().item()
        FP = ((preds == 1) & (labels == 0)).sum().item()
        FN = ((preds == 0) & (labels == 1)).sum().item()
        TN = ((preds == 0) & (labels == 0)).sum().item()
        precision = TP / (TP + FP) if (TP + FP) > 0 else 0.0
        recall = TP / (TP + FN) if (TP + FN) > 0 else 0.0
        f1 = (
            2 * precision * recall / (precision + recall)
            if (precision + recall) > 0
            else 0.0
        )
        accuracy = (TP + TN) / (TP + TN + FP + FN) if (TP + TN + FP + FN) > 0 else 0.0
        return precision, recall, f1, accuracy
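

# Sanity check (hypothetical logits): for
# output = torch.tensor([[0.0, 10.0], [10.0, 0.0]]) and
# labels = torch.tensor([1, 0]), threshold=0.5 yields TP=1 and TN=1, so
# precision, recall, f1, and accuracy are all 1.0.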


def evaluate(
    model: MmpNet,
    criterion: nn.Module,
    dataloader: DataLoader,
) -> tuple[float, float, float, float, float]:
    """Returns the average loss, precision, recall, f1, and accuracy."""
    device = next(model.parameters()).device
    model.eval()
    total_loss = 0.0
    total_samples = 0
    all_outputs = []
    all_labels = []
    with torch.no_grad():
        for img_batch, lbl_batch, _ in dataloader:
            img_batch = img_batch.to(device)
            lbl_batch = lbl_batch.to(device)
            outputs = model(img_batch)
            loss = criterion(outputs, lbl_batch)
            batch_size = img_batch.size(0)
            # Weight by batch size so avg_loss is per sample, not per batch.
            total_loss += loss.item() * batch_size
            total_samples += batch_size
            all_outputs.append(outputs.cpu())
            all_labels.append(lbl_batch.cpu())
    avg_loss = total_loss / total_samples if total_samples > 0 else 0.0
    if all_outputs and all_labels:
        outputs_cat = torch.cat(all_outputs)
        labels_cat = torch.cat(all_labels)
        precision, recall, f1, acc = get_detection_metrics(
            outputs_cat, labels_cat, threshold=0.5
        )
    else:
        # Empty dataloader: define every metric that is returned below.
        precision = recall = f1 = acc = 0.0
    return avg_loss, precision, recall, f1, acc
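

# Note: main() calls evaluate() with mining disabled (criterion_eval), so the
# reported validation loss is the plain mean cross-entropy over all anchors
# and stays comparable across epochs regardless of the sampled mask.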


def train(
    model: MmpNet,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
) -> float:
    model.train()
    running_loss = 0.0
    total_samples = 0
    progress_bar = tqdm(loader, desc="Training", unit="batch")
    for img_batch, lbl_batch, _ in progress_bar:
        loss = step(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            img_batch=img_batch,
            lbl_batch=lbl_batch,
        )
        batch_size = img_batch.size(0)
        running_loss += loss * batch_size
        total_samples += batch_size
        # Show the running per-sample loss in the progress bar.
        progress_bar.set_postfix(
            {"loss": running_loss / total_samples if total_samples > 0 else 0.0}
        )
    epoch_loss = running_loss / total_samples if total_samples > 0 else 0.0
    progress_bar.close()
    return epoch_loss


class NegativeMiningCriterion(nn.Module):
    def __init__(self, neg_ratio: float = 3.0, enable_negative_mining: bool = True):
        super().__init__()
        # Per-element loss so it can be masked before reducing.
        self.backbone = nn.CrossEntropyLoss(reduction="none")
        self.neg_ratio = neg_ratio
        self.enable_negative_mining = enable_negative_mining

    def forward(self, outputs: Tensor, labels: Tensor) -> Tensor:
        outputs_flat = outputs.view(-1, outputs.shape[-1])
        labels_flat = labels.view(-1).long()
        unfiltered = self.backbone(outputs_flat, labels_flat)
        assert unfiltered.shape == labels_flat.shape
        if not self.enable_negative_mining:
            return unfiltered.mean()
        mask = get_random_sampling_mask(labels_flat, self.neg_ratio)
        filtered_loss = unfiltered[mask == 1]
        if filtered_loss.numel() == 0:
            # No positives in the batch means an all-zero mask; fall back to
            # the unfiltered mean instead of returning NaN.
            return unfiltered.mean()
        return filtered_loss.mean()
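

# Usage sketch (hypothetical tensors): with crit = NegativeMiningCriterion(),
# crit(outputs, labels) flattens outputs to (N, 2) and labels to (N,), keeps
# every positive anchor plus up to neg_ratio-times-as-many random negatives,
# and averages the per-anchor cross-entropy over the kept anchors only.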


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tensorboard",
        nargs="?",
        const=True,
        default=False,
        help="Enable TensorBoard logging. If a label is provided, it will be used in the log directory name.",
    )
    args = parser.parse_args()
    if args.tensorboard:
        from torch.utils.tensorboard import SummaryWriter

        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        if isinstance(args.tensorboard, str):
            label = args.tensorboard
            log_dir = f"runs/a5_mmpnet_{label}_{timestamp}"
        else:
            log_dir = f"runs/a5_mmpnet_{timestamp}"
        writer = SummaryWriter(log_dir=log_dir)
    else:
        writer = None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The 8 anchor widths and 8 aspect ratios below must match num_widths and
    # num_aspect_ratios of the model.
    model = MmpNet(num_aspect_ratios=8, num_widths=8).to(device)
    anchor_grid = get_anchor_grid(
        anchor_widths=[8, 16, 32, 64, 96, 128, 160, 192],
        aspect_ratios=[1 / 2, 2 / 3, 1, 4 / 3, 5 / 3, 2, 2.5, 3],
        num_rows=7,
        num_cols=7,
        scale_factor=32,
    )
    dataloader_train = get_dataloader(
        path_to_data=".data/mmp-public-3.2/train",
        image_size=224,
        batch_size=32,
        num_workers=9,
        is_test=False,
        is_train=True,
        anchor_grid=anchor_grid,
    )
    dataloader_val = get_dataloader(
        path_to_data=".data/mmp-public-3.2/val",
        image_size=224,
        batch_size=32,
        num_workers=9,
        is_test=False,
        is_train=False,
        anchor_grid=anchor_grid,
    )
    _, optimizer = get_criterion_optimizer(model=model)
    # Train with negative mining; evaluate on the plain (unmined) loss so the
    # validation numbers are comparable across epochs.
    criterion = NegativeMiningCriterion(enable_negative_mining=True)
    criterion_eval = NegativeMiningCriterion(enable_negative_mining=False)

    num_epochs = 10
    for epoch in range(num_epochs):
        train_loss = train(
            model=model,
            loader=dataloader_train,
            criterion=criterion,
            optimizer=optimizer,
        )
        avg_loss, precision, recall, f1, acc = evaluate(
            model=model, criterion=criterion_eval, dataloader=dataloader_val
        )
        if writer is not None:
            writer.add_scalar("Loss/train_epoch", train_loss, epoch)
            writer.add_scalar("Loss/eval_epoch", avg_loss, epoch)
            writer.add_scalar("Acc/precision", precision, epoch)
            writer.add_scalar("Acc/recall", recall, epoch)
            writer.add_scalar("Acc/acc", acc, epoch)
            writer.add_scalar("Acc/f1", f1, epoch)
    if writer is not None:
        writer.close()


if __name__ == "__main__":
    main()