import argparse import torch import torch.optim as optim import torch.nn as nn from torch.utils.data import DataLoader from torch import Tensor from tqdm import tqdm import datetime from .model import MmpNet from ..a4.anchor_grid import get_anchor_grid from ..a4.dataset import get_dataloader from ..a2.main import get_criterion_optimizer def step( model: MmpNet, criterion, optimizer: optim.Optimizer, img_batch: torch.Tensor, lbl_batch: torch.Tensor, ) -> float: model.train() optimizer.zero_grad() device = next(model.parameters()).device img_batch = img_batch.to(device) lbl_batch = lbl_batch.to(device) outputs = model(img_batch) loss = criterion(outputs, lbl_batch) loss.backward() optimizer.step() return loss.item() def get_random_sampling_mask(labels: torch.Tensor, neg_ratio: float) -> torch.Tensor: """ @param labels: The label tensor that is returned by your data loader. The values are either 0 (negative label) or 1 (positive label). @param neg_ratio: The desired negative/positive ratio. Hint: after computing the mask, check if the neg_ratio is fulfilled. @return: A tensor with the same shape as labels """ # Flatten for easier indexing labels_flat = labels.view(-1) pos_indices = (labels_flat == 1).nonzero(as_tuple=True)[0] neg_indices = (labels_flat == 0).nonzero(as_tuple=True)[0] num_pos = pos_indices.numel() num_neg = neg_indices.numel() num_neg_to_sample = min(int(neg_ratio * num_pos), num_neg) perm = torch.randperm(num_neg, device=labels.device) sampled_neg_indices = neg_indices[perm[:num_neg_to_sample]] mask_flat = torch.zeros_like(labels_flat, dtype=torch.long) mask_flat[pos_indices] = 1 mask_flat[sampled_neg_indices] = 1 # Reshape to original shape mask = mask_flat.view_as(labels) return mask def get_detection_metrics( output: Tensor, labels: torch.Tensor, threshold: float ) -> tuple[float, float, float]: """ Returns precision, recall, f1 for the positive (human) class. """ with torch.no_grad(): probs = torch.softmax(output, dim=-1)[..., 1] preds = probs >= threshold TP = ((preds == 1) & (labels == 1)).sum().item() FP = ((preds == 1) & (labels == 0)).sum().item() FN = ((preds == 0) & (labels == 1)).sum().item() TN = ((preds == 0) & (labels == 0)).sum().item() precision = TP / (TP + FP) if (TP + FP) > 0 else 0.0 neg_precision = TN / (TN + FN) if (TN + FN) > 0 else 0.0 recall = TP / (TP + FN) if (TP + FN) > 0 else 0.0 neg_recall = TN / (TN + FP) if (TN + FP) > 0 else 0.0 f1 = ( 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 ) return precision, recall, f1, neg_precision, neg_recall def evaluate( model: MmpNet, criterion, dataloader: DataLoader, ) -> tuple[float, float, float, float]: device = next(model.parameters()).device model.eval() total_loss = 0.0 total_samples = 0 all_outputs = [] all_labels = [] with torch.no_grad(): for img_batch, lbl_batch, _ in dataloader: img_batch = img_batch.to(device) lbl_batch = lbl_batch.to(device) outputs = model(img_batch) loss = criterion(outputs, lbl_batch) batch_size = img_batch.size(0) total_loss += loss.item() * batch_size total_samples += batch_size all_outputs.append(outputs.cpu()) all_labels.append(lbl_batch.cpu()) avg_loss = total_loss / total_samples if total_samples > 0 else 0.0 if all_outputs and all_labels: outputs_cat = torch.cat(all_outputs) labels_cat = torch.cat(all_labels) precision, recall, f1, neg_precision, neg_recall = get_detection_metrics( outputs_cat, labels_cat, threshold=0.5 ) else: precision = recall = f1 = 0.0 return avg_loss, precision, recall, f1, neg_precision, neg_recall def train( model: MmpNet, loader: DataLoader, criterion: nn.Module, optimizer: optim.Optimizer, ): model.train() running_loss = 0.0 total_samples = 0 progress_bar = tqdm(loader, desc="Training", unit="batch") for img_batch, lbl_batch, _ in progress_bar: loss = step( model=model, criterion=criterion, optimizer=optimizer, img_batch=img_batch, lbl_batch=lbl_batch, ) batch_size = img_batch.size(0) running_loss += loss * batch_size total_samples += batch_size progress_bar.set_postfix( {"loss": running_loss / total_samples if total_samples > 0 else 0.0} ) epoch_loss = running_loss / total_samples if total_samples > 0 else 0.0 progress_bar.close() return epoch_loss class NegativeMiningCriterion(nn.Module): def __init__(self, neg_ratio=3.0, enable_negative_mining: bool = True): super().__init__() self.backbone = nn.CrossEntropyLoss(reduction="none") self.neg_ratio = neg_ratio self.enable_negative_mining = enable_negative_mining def forward(self, outputs, labels): outputs_flat = outputs.view(-1, outputs.shape[-1]) labels_flat = labels.view(-1).long() unfiltered = self.backbone(outputs_flat, labels_flat) assert unfiltered.shape == labels_flat.shape if not self.enable_negative_mining: return unfiltered.mean() mask = get_random_sampling_mask(labels_flat, self.neg_ratio) filtered_loss = unfiltered[mask == 1] return filtered_loss.mean() def main(): parser = argparse.ArgumentParser() parser.add_argument( "--tensorboard", nargs="?", const=True, default=False, help="Enable TensorBoard logging. If a label is provided, it will be used in the log directory name.", ) args = parser.parse_args() if args.tensorboard: from torch.utils.tensorboard import SummaryWriter timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") if isinstance(args.tensorboard, str): label = args.tensorboard log_dir = f"runs/a5_mmpnet_{label}_{timestamp}" else: log_dir = f"runs/a5_mmpnet_{timestamp}" writer = SummaryWriter(log_dir=log_dir) else: writer = None device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = MmpNet(num_aspect_ratios=8, num_widths=8).to(device) anchor_grid = get_anchor_grid( anchor_widths=[8, 16, 32, 64, 96, 128, 160, 192], aspect_ratios=[1 / 2, 2 / 3, 1, 4 / 3, 5 / 3, 2, 2.5, 3], num_rows=7, num_cols=7, scale_factor=32, ) dataloader_train = get_dataloader( path_to_data=".data/mmp-public-3.2/train", image_size=224, batch_size=32, num_workers=9, is_test=False, is_train=True, anchor_grid=anchor_grid, ) dataloader_val = get_dataloader( path_to_data=".data/mmp-public-3.2/val", image_size=224, batch_size=32, num_workers=9, is_test=False, is_train=False, anchor_grid=anchor_grid, ) _, optimizer = get_criterion_optimizer(model=model) criterion = NegativeMiningCriterion(enable_negative_mining=True) criterion_eval = NegativeMiningCriterion(enable_negative_mining=False) num_epochs = 7 for epoch in range(num_epochs): train_loss = train( model=model, loader=dataloader_train, criterion=criterion, optimizer=optimizer, ) avg_loss, precision, recall, f1, neg_precision, neg_recall = evaluate( model=model, criterion=criterion_eval, dataloader=dataloader_val ) if writer is not None: writer.add_scalar("Loss/train_epoch", train_loss, epoch) writer.add_scalar("Loss/eval_epoch", avg_loss, epoch) writer.add_scalar("Acc/precision", precision, epoch) writer.add_scalar("Acc/recall", recall, epoch) writer.add_scalar("Acc/neg_precision", neg_precision, epoch) writer.add_scalar("Acc/neg_recall", neg_recall, epoch) writer.add_scalar("Acc/f1", f1, epoch) if writer is not None: writer.close() if __name__ == "__main__": main()