import argparse
import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from .model import MmpNet
from ..a2.main import get_criterion_optimizer
from ..a4.anchor_grid import get_anchor_grid
from ..a4.dataset import get_dataloader
from ..a6.main import evaluate as evaluate_v2


def step(
    model: MmpNet,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    img_batch: torch.Tensor,
    lbl_batch: torch.Tensor,
) -> float:
    """Runs a single optimization step and returns the batch loss."""
    model.train()
    optimizer.zero_grad()
    device = next(model.parameters()).device
    img_batch = img_batch.to(device)
    lbl_batch = lbl_batch.to(device)
    outputs = model(img_batch)
    loss = criterion(outputs, lbl_batch)
    loss.backward()
    optimizer.step()
    return loss.item()


def get_random_sampling_mask(labels: torch.Tensor, neg_ratio: float) -> torch.Tensor:
    """
    @param labels: The label tensor returned by the data loader. Values are
        either 0 (negative label) or 1 (positive label).
    @param neg_ratio: The desired negative/positive ratio. Hint: after
        computing the mask, check that the neg_ratio is fulfilled.
    @return: A tensor with the same shape as labels.
    """
    # Flatten for easier indexing.
    labels_flat = labels.view(-1)
    pos_indices = (labels_flat == 1).nonzero(as_tuple=True)[0]
    neg_indices = (labels_flat == 0).nonzero(as_tuple=True)[0]

    num_pos = pos_indices.numel()
    num_neg = neg_indices.numel()

    # Keep at most neg_ratio negatives per positive, capped by availability.
    num_neg_to_sample = min(int(neg_ratio * num_pos), num_neg)
    perm = torch.randperm(num_neg, device=labels.device)
    sampled_neg_indices = neg_indices[perm[:num_neg_to_sample]]

    mask_flat = torch.zeros_like(labels_flat, dtype=torch.long)
    mask_flat[pos_indices] = 1
    mask_flat[sampled_neg_indices] = 1

    # Reshape to the original shape.
    return mask_flat.view_as(labels)


def evaluate(
    model: MmpNet,
    criterion: nn.Module,
    dataloader: DataLoader,
) -> float:
    """Returns the average per-sample loss over the given dataloader."""
    device = next(model.parameters()).device
    model.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for img_batch, lbl_batch, _ in dataloader:
            img_batch = img_batch.to(device)
            lbl_batch = lbl_batch.to(device)
            outputs = model(img_batch)
            loss = criterion(outputs, lbl_batch)
            batch_size = img_batch.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size
    return total_loss / total_samples if total_samples > 0 else 0.0


def train(
    model: MmpNet,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
) -> float:
    """Trains for one epoch and returns the average per-sample loss."""
    model.train()
    running_loss = 0.0
    total_samples = 0
    progress_bar = tqdm(loader, desc="Training", unit="batch")
    for img_batch, lbl_batch, _ in progress_bar:
        loss = step(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            img_batch=img_batch,
            lbl_batch=lbl_batch,
        )
        batch_size = img_batch.size(0)
        running_loss += loss * batch_size
        total_samples += batch_size
        progress_bar.set_postfix(
            {"loss": running_loss / total_samples if total_samples > 0 else 0.0}
        )
    progress_bar.close()
    return running_loss / total_samples if total_samples > 0 else 0.0
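# A minimal sanity check for get_random_sampling_mask, kept out of the training
# path. The label shape below is an illustrative assumption, not the real
# anchor-grid layout; with neg_ratio=3.0 the mask must keep every positive and
# at most three negatives per positive.
def _demo_sampling_mask() -> None:
    labels = torch.zeros(2, 4, 4, dtype=torch.long)
    labels[0, 0, 0] = 1
    labels[1, 2, 3] = 1
    mask = get_random_sampling_mask(labels, neg_ratio=3.0)
    num_pos = int((labels == 1).sum())
    num_neg_kept = int(((mask == 1) & (labels == 0)).sum())
    assert bool((mask[labels == 1] == 1).all())  # every positive survives
    assert num_neg_kept <= 3 * num_pos  # negative/positive ratio is respected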
class NegativeMiningCriterion(nn.Module):
    """Cross-entropy loss with optional random negative sampling."""

    def __init__(self, neg_ratio: float = 3.0, enable_negative_mining: bool = True):
        super().__init__()
        self.backbone = nn.CrossEntropyLoss(reduction="none")
        self.neg_ratio = neg_ratio
        self.enable_negative_mining = enable_negative_mining

    def forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        outputs_flat = outputs.view(-1, outputs.shape[-1])
        labels_flat = labels.view(-1).long()
        unfiltered = self.backbone(outputs_flat, labels_flat)
        assert unfiltered.shape == labels_flat.shape
        if not self.enable_negative_mining:
            return unfiltered.mean()
        mask = get_random_sampling_mask(labels_flat, self.neg_ratio)
        filtered_loss = unfiltered[mask == 1]
        # Guard against batches without positives: the mean of an empty
        # selection would be NaN, so fall back to the unfiltered loss.
        if filtered_loss.numel() == 0:
            return unfiltered.mean()
        return filtered_loss.mean()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tensorboard",
        nargs="?",
        const=True,
        default=False,
        help=(
            "Enable TensorBoard logging. If a label is provided, it will be "
            "used in the log directory name."
        ),
    )
    args = parser.parse_args()

    if args.tensorboard:
        from torch.utils.tensorboard import SummaryWriter

        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        if isinstance(args.tensorboard, str):
            label = args.tensorboard
            log_dir = f"runs/a5_mmpnet_{label}_{timestamp}"
        else:
            log_dir = f"runs/a5_mmpnet_{timestamp}"
        writer = SummaryWriter(log_dir=log_dir)
    else:
        writer = None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MmpNet(num_aspect_ratios=8, num_widths=8).to(device)
    anchor_grid = get_anchor_grid(
        anchor_widths=[8, 16, 32, 64, 96, 128, 160, 192],
        aspect_ratios=[1 / 2, 2 / 3, 1, 4 / 3, 5 / 3, 2, 2.5, 3],
        num_rows=7,
        num_cols=7,
        scale_factor=32,
    )
    dataloader_train = get_dataloader(
        path_to_data=".data/mmp-public-3.2/train",
        image_size=224,
        batch_size=32,
        num_workers=9,
        is_test=False,
        is_train=True,
        anchor_grid=anchor_grid,
    )
    dataloader_val = get_dataloader(
        path_to_data=".data/mmp-public-3.2/val",
        image_size=224,
        batch_size=32,
        num_workers=9,
        is_test=False,
        is_train=False,
        anchor_grid=anchor_grid,
    )

    _, optimizer = get_criterion_optimizer(model=model)
    # Train with negative mining; evaluate on the plain (unfiltered) loss.
    criterion = NegativeMiningCriterion(enable_negative_mining=True)
    criterion_eval = NegativeMiningCriterion(enable_negative_mining=False)

    num_epochs = 5
    for epoch in range(num_epochs):
        train_loss = train(
            model=model,
            loader=dataloader_train,
            criterion=criterion,
            optimizer=optimizer,
        )
        avg_loss = evaluate(
            model=model, criterion=criterion_eval, dataloader=dataloader_val
        )
        _ = evaluate_v2(
            model=model,
            device=device,
            anchor_grid=anchor_grid,
            loader=dataloader_train,
        )
        if writer is not None:
            writer.add_scalar("Loss/train_epoch", train_loss, epoch)
            writer.add_scalar("Loss/eval_epoch", avg_loss, epoch)

    if writer is not None:
        writer.close()


if __name__ == "__main__":
    main()
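# Hedged usage sketch for NegativeMiningCriterion on random tensors. The
# output layout (..., num_classes=2) matching the label shape is an assumption
# about MmpNet's head, not something this file guarantees:
#
#   outputs = torch.randn(32, 7, 7, 8, 8, 2)
#   labels = torch.randint(0, 2, (32, 7, 7, 8, 8))
#   loss = NegativeMiningCriterion(neg_ratio=3.0)(outputs, labels)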