# mmp_wise2526_franksim/mmp/a5/main.py
import argparse
import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from .model import MmpNet
from ..a2.main import get_criterion_optimizer
from ..a4.anchor_grid import get_anchor_grid
from ..a4.dataset import get_dataloader
from ..a6.main import evaluate as evaluate_v2


def step(
    model: MmpNet,
    criterion,
    optimizer: optim.Optimizer,
    img_batch: torch.Tensor,
    lbl_batch: torch.Tensor,
) -> float:
    """Runs a single optimization step and returns the batch loss."""
    model.train()
    optimizer.zero_grad()

    # Move the batch to the same device as the model parameters.
    device = next(model.parameters()).device
    img_batch = img_batch.to(device)
    lbl_batch = lbl_batch.to(device)
    outputs = model(img_batch)
    loss = criterion(outputs, lbl_batch)
    loss.backward()
    optimizer.step()
    return loss.item()


def get_random_sampling_mask(labels: torch.Tensor, neg_ratio: float) -> torch.Tensor:
    """
    @param labels: Label tensor as returned by the data loader; values are
        0 (negative label) or 1 (positive label).
    @param neg_ratio: The desired negative/positive ratio.
    @return: A binary mask with the same shape as labels that selects all
        positives and a random subset of negatives, capped so the ratio
        never exceeds neg_ratio.
    """
    # Flatten for easier indexing.
    labels_flat = labels.view(-1)
    pos_indices = (labels_flat == 1).nonzero(as_tuple=True)[0]
    neg_indices = (labels_flat == 0).nonzero(as_tuple=True)[0]
    num_pos = pos_indices.numel()
    num_neg = neg_indices.numel()
    # Sample at most neg_ratio negatives per positive (fewer if the batch
    # does not contain enough negatives).
    num_neg_to_sample = min(int(neg_ratio * num_pos), num_neg)
    perm = torch.randperm(num_neg, device=labels.device)
    sampled_neg_indices = neg_indices[perm[:num_neg_to_sample]]
    mask_flat = torch.zeros_like(labels_flat, dtype=torch.long)
    mask_flat[pos_indices] = 1
    mask_flat[sampled_neg_indices] = 1
    # Reshape to the original label shape.
    mask = mask_flat.view_as(labels)
    return mask
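
# Worked example (a sketch, values illustrative): with
#   labels = tensor([1, 0, 0, 0, 0, 0, 0]) and neg_ratio = 3.0,
# get_random_sampling_mask keeps the single positive and samples
# min(int(3.0 * 1), 6) = 3 of the 6 negatives, e.g. returning
# tensor([1, 0, 1, 1, 0, 0, 1]); the chosen negative positions differ
# between calls because of torch.randperm.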


def evaluate(
    model: MmpNet,
    criterion,
    dataloader: DataLoader,
) -> float:
    """Computes the average per-sample loss over the given dataloader."""
    device = next(model.parameters()).device
    model.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for img_batch, lbl_batch, _ in dataloader:
            img_batch = img_batch.to(device)
            lbl_batch = lbl_batch.to(device)
            outputs = model(img_batch)
            loss = criterion(outputs, lbl_batch)
            batch_size = img_batch.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size
    avg_loss = total_loss / total_samples if total_samples > 0 else 0.0
    return avg_loss


def train(
    model: MmpNet,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
) -> float:
    """Trains for one epoch and returns the average loss per sample."""
    model.train()
    running_loss = 0.0
    total_samples = 0
    progress_bar = tqdm(loader, desc="Training", unit="batch")
    for img_batch, lbl_batch, _ in progress_bar:
        loss = step(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            img_batch=img_batch,
            lbl_batch=lbl_batch,
        )
        batch_size = img_batch.size(0)
        running_loss += loss * batch_size
        total_samples += batch_size
        progress_bar.set_postfix(
            {"loss": running_loss / total_samples if total_samples > 0 else 0.0}
        )
    epoch_loss = running_loss / total_samples if total_samples > 0 else 0.0
    progress_bar.close()
    return epoch_loss


class NegativeMiningCriterion(nn.Module):
    """Cross-entropy loss with optional random negative mining."""

    def __init__(self, neg_ratio: float = 3.0, enable_negative_mining: bool = True):
        super().__init__()
        self.backbone = nn.CrossEntropyLoss(reduction="none")
        self.neg_ratio = neg_ratio
        self.enable_negative_mining = enable_negative_mining

    def forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Flatten predictions to (N, num_classes) and labels to (N,).
        outputs_flat = outputs.view(-1, outputs.shape[-1])
        labels_flat = labels.view(-1).long()
        unfiltered = self.backbone(outputs_flat, labels_flat)
        assert unfiltered.shape == labels_flat.shape
        if not self.enable_negative_mining:
            return unfiltered.mean()
        # Keep all positives plus a random subset of negatives and average
        # the loss over the selected entries only.
        mask = get_random_sampling_mask(labels_flat, self.neg_ratio)
        filtered_loss = unfiltered[mask == 1]
        return filtered_loss.mean()
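
# Minimal usage sketch (shapes assumed for illustration, not prescribed by
# the class): per-anchor logits of shape (batch, rows, cols, anchors, 2)
# and 0/1 labels of shape (batch, rows, cols, anchors) both flatten
# cleanly inside forward():
#   criterion = NegativeMiningCriterion(neg_ratio=3.0)
#   loss = criterion(torch.randn(2, 7, 7, 64, 2), torch.randint(0, 2, (2, 7, 7, 64)))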


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tensorboard",
        nargs="?",
        const=True,
        default=False,
        help="Enable TensorBoard logging. If a label is given, it is used in the log directory name.",
    )
    args = parser.parse_args()

    if args.tensorboard:
        from torch.utils.tensorboard import SummaryWriter

        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        if isinstance(args.tensorboard, str):
            label = args.tensorboard
            log_dir = f"runs/a5_mmpnet_{label}_{timestamp}"
        else:
            log_dir = f"runs/a5_mmpnet_{timestamp}"
        writer = SummaryWriter(log_dir=log_dir)
    else:
        writer = None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MmpNet(num_aspect_ratios=8, num_widths=8).to(device)
    anchor_grid = get_anchor_grid(
        anchor_widths=[8, 16, 32, 64, 96, 128, 160, 192],
        aspect_ratios=[1 / 2, 2 / 3, 1, 4 / 3, 5 / 3, 2, 2.5, 3],
        num_rows=7,
        num_cols=7,
        scale_factor=32,
    )
    dataloader_train = get_dataloader(
        path_to_data=".data/mmp-public-3.2/train",
        image_size=224,
        batch_size=32,
        num_workers=9,
        is_test=False,
        is_train=True,
        anchor_grid=anchor_grid,
    )
    dataloader_val = get_dataloader(
        path_to_data=".data/mmp-public-3.2/val",
        image_size=224,
        batch_size=32,
        num_workers=9,
        is_test=False,
        is_train=False,
        anchor_grid=anchor_grid,
    )
    _, optimizer = get_criterion_optimizer(model=model)
    # Mine negatives during training; evaluate on the unfiltered loss so
    # epoch losses stay comparable.
    criterion = NegativeMiningCriterion(enable_negative_mining=True)
    criterion_eval = NegativeMiningCriterion(enable_negative_mining=False)

    num_epochs = 5
    for epoch in range(num_epochs):
        train_loss = train(
            model=model,
            loader=dataloader_train,
            criterion=criterion,
            optimizer=optimizer,
        )
        avg_loss = evaluate(
            model=model, criterion=criterion_eval, dataloader=dataloader_val
        )
        _ = evaluate_v2(
            model=model, device=device, anchor_grid=anchor_grid, loader=dataloader_train
        )
        if writer is not None:
            writer.add_scalar("Loss/train_epoch", train_loss, epoch)
            writer.add_scalar("Loss/eval_epoch", avg_loss, epoch)
    if writer is not None:
        writer.close()

if __name__ == "__main__":
main()
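
# How to run (a sketch, assuming the mmp package is importable from the
# repository root; the relative imports require module-style execution):
#   python -m mmp.a5.main --tensorboard my_label
# The --tensorboard flag is optional; with a label, logs are written to
# runs/a5_mmpnet_my_label_<timestamp>, otherwise to runs/a5_mmpnet_<timestamp>.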