import argparse
import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from .model import MmpNet
from ..a2.main import get_criterion_optimizer
from ..a4.anchor_grid import get_anchor_grid
from ..a4.dataset import get_dataloader
from ..a6.main import evaluate as evaluate_v2


def step(
    model: MmpNet,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    img_batch: torch.Tensor,
    lbl_batch: torch.Tensor,
) -> float:
    """Run a single optimization step on one batch and return the batch loss."""
    model.train()
    optimizer.zero_grad()

    # Move the batch to the same device as the model parameters.
    device = next(model.parameters()).device
    img_batch = img_batch.to(device)
    lbl_batch = lbl_batch.to(device)

    outputs = model(img_batch)
    loss = criterion(outputs, lbl_batch)
    loss.backward()
    optimizer.step()

    return loss.item()
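

# Hedged usage sketch, not part of the training pipeline: `step` only relies on
# the nn.Module interface, so a tiny linear model stands in for MmpNet here.
# The stand-in model and the helper name are illustrative assumptions.
def _sanity_check_step() -> None:
    model = nn.Linear(4, 2)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    img_batch = torch.randn(8, 4)
    lbl_batch = torch.randint(0, 2, (8,))
    loss = step(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        img_batch=img_batch,
        lbl_batch=lbl_batch,
    )
    assert isinstance(loss, float)  # step returns a plain Python float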


def get_random_sampling_mask(labels: torch.Tensor, neg_ratio: float) -> torch.Tensor:
    """Build a 0/1 mask that keeps all positives and a random subset of negatives.

    @param labels: The label tensor returned by the data loader; values are
        either 0 (negative label) or 1 (positive label).
    @param neg_ratio: The desired negative/positive ratio. The ratio can only
        be fulfilled if enough negatives are available; otherwise all
        negatives are kept.
    @return: A tensor with the same shape as labels.
    """
    # Flatten for easier indexing.
    labels_flat = labels.view(-1)
    pos_indices = (labels_flat == 1).nonzero(as_tuple=True)[0]
    neg_indices = (labels_flat == 0).nonzero(as_tuple=True)[0]

    num_pos = pos_indices.numel()
    num_neg = neg_indices.numel()

    # Keep at most neg_ratio negatives per positive, capped by availability.
    num_neg_to_sample = min(int(neg_ratio * num_pos), num_neg)

    # Sample negatives uniformly at random, without replacement.
    perm = torch.randperm(num_neg, device=labels.device)
    sampled_neg_indices = neg_indices[perm[:num_neg_to_sample]]

    mask_flat = torch.zeros_like(labels_flat, dtype=torch.long)
    mask_flat[pos_indices] = 1
    mask_flat[sampled_neg_indices] = 1

    # Reshape to the original shape.
    return mask_flat.view_as(labels)
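

# Hedged sanity check (illustrative helper, not part of the assignment API):
# with 2 positives and neg_ratio=3.0, the mask must keep every positive and
# at most 3 * 2 = 6 negatives.
def _sanity_check_sampling_mask() -> None:
    labels = torch.tensor([1, 0, 0, 0, 0, 0, 0, 0, 1, 0])
    mask = get_random_sampling_mask(labels, neg_ratio=3.0)
    assert bool((mask[labels == 1] == 1).all())  # every positive is kept
    assert int(mask[labels == 0].sum()) <= 6  # capped at neg_ratio * num_pos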


def evaluate(
    model: MmpNet,
    criterion: nn.Module,
    dataloader: DataLoader,
) -> float:
    """Compute the average per-sample loss over the given dataloader."""
    device = next(model.parameters()).device
    model.eval()
    total_loss = 0.0
    total_samples = 0

    with torch.no_grad():
        for img_batch, lbl_batch, _ in dataloader:
            img_batch = img_batch.to(device)
            lbl_batch = lbl_batch.to(device)

            outputs = model(img_batch)
            loss = criterion(outputs, lbl_batch)

            # Weight by batch size so the result is a per-sample average even
            # if the last batch is smaller.
            batch_size = img_batch.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

    return total_loss / total_samples if total_samples > 0 else 0.0


def train(
    model: MmpNet,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
) -> float:
    """Train for one epoch and return the average per-sample loss."""
    model.train()
    running_loss = 0.0
    total_samples = 0

    progress_bar = tqdm(loader, desc="Training", unit="batch")
    for img_batch, lbl_batch, _ in progress_bar:
        loss = step(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            img_batch=img_batch,
            lbl_batch=lbl_batch,
        )
        batch_size = img_batch.size(0)
        running_loss += loss * batch_size
        total_samples += batch_size
        progress_bar.set_postfix(
            {"loss": running_loss / total_samples if total_samples > 0 else 0.0}
        )

    progress_bar.close()
    return running_loss / total_samples if total_samples > 0 else 0.0


class NegativeMiningCriterion(nn.Module):
    """Cross-entropy loss with optional random negative sampling."""

    def __init__(self, neg_ratio: float = 3.0, enable_negative_mining: bool = True):
        super().__init__()
        # Per-element losses so that individual anchors can be masked out.
        self.backbone = nn.CrossEntropyLoss(reduction="none")
        self.neg_ratio = neg_ratio
        self.enable_negative_mining = enable_negative_mining

    def forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Collapse all anchor dimensions; the class dimension is last.
        outputs_flat = outputs.view(-1, outputs.shape[-1])
        labels_flat = labels.view(-1).long()
        unfiltered = self.backbone(outputs_flat, labels_flat)
        assert unfiltered.shape == labels_flat.shape

        if not self.enable_negative_mining:
            return unfiltered.mean()

        # Average only over positives and the sampled subset of negatives.
        mask = get_random_sampling_mask(labels_flat, self.neg_ratio)
        return unfiltered[mask == 1].mean()
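

# Hedged usage sketch: the criterion expects raw scores with the class
# dimension last and integer 0/1 labels. The tensor shapes below are
# illustrative assumptions, not MmpNet's actual output layout.
def _sanity_check_criterion() -> None:
    criterion = NegativeMiningCriterion(neg_ratio=3.0)
    outputs = torch.randn(2, 4, 4, 2)  # e.g. (batch, rows, cols, classes)
    labels = torch.randint(0, 2, (2, 4, 4))
    loss = criterion(outputs, labels)
    assert loss.dim() == 0  # a scalar, ready for loss.backward()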


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tensorboard",
        nargs="?",
        const=True,
        default=False,
        help="Enable TensorBoard logging. If a label is provided, it will be used in the log directory name.",
    )
    args = parser.parse_args()

    if args.tensorboard:
        from torch.utils.tensorboard import SummaryWriter

        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        if isinstance(args.tensorboard, str):
            log_dir = f"runs/a5_mmpnet_{args.tensorboard}_{timestamp}"
        else:
            log_dir = f"runs/a5_mmpnet_{timestamp}"
        writer = SummaryWriter(log_dir=log_dir)
    else:
        writer = None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MmpNet(num_aspect_ratios=8, num_widths=8).to(device)

    anchor_grid = get_anchor_grid(
        anchor_widths=[8, 16, 32, 64, 96, 128, 160, 192],
        aspect_ratios=[1 / 2, 2 / 3, 1, 4 / 3, 5 / 3, 2, 2.5, 3],
        num_rows=7,
        num_cols=7,
        scale_factor=32,
    )

    dataloader_train = get_dataloader(
        path_to_data=".data/mmp-public-3.2/train",
        image_size=224,
        batch_size=32,
        num_workers=9,
        is_test=False,
        is_train=True,
        anchor_grid=anchor_grid,
    )
    dataloader_val = get_dataloader(
        path_to_data=".data/mmp-public-3.2/val",
        image_size=224,
        batch_size=32,
        num_workers=9,
        is_test=False,
        is_train=False,
        anchor_grid=anchor_grid,
    )

    _, optimizer = get_criterion_optimizer(model=model)
    criterion = NegativeMiningCriterion(enable_negative_mining=True)
    # Evaluate with mining disabled so validation losses stay comparable
    # across epochs.
    criterion_eval = NegativeMiningCriterion(enable_negative_mining=False)
    num_epochs = 5

    for epoch in range(num_epochs):
        train_loss = train(
            model=model,
            loader=dataloader_train,
            criterion=criterion,
            optimizer=optimizer,
        )
        avg_loss = evaluate(
            model=model, criterion=criterion_eval, dataloader=dataloader_val
        )
        _ = evaluate_v2(
            model=model, device=device, anchor_grid=anchor_grid, loader=dataloader_train
        )

        if writer is not None:
            writer.add_scalar("Loss/train_epoch", train_loss, epoch)
            writer.add_scalar("Loss/eval_epoch", avg_loss, epoch)

    if writer is not None:
        writer.close()


if __name__ == "__main__":
    main()