adds nms and eval
This commit is contained in:
@@ -3,7 +3,6 @@ import torch
|
||||
import torch.optim as optim
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch import Tensor
|
||||
from tqdm import tqdm
|
||||
import datetime
|
||||
|
||||
@@ -11,6 +10,7 @@ from .model import MmpNet
|
||||
from ..a4.anchor_grid import get_anchor_grid
|
||||
from ..a4.dataset import get_dataloader
|
||||
from ..a2.main import get_criterion_optimizer
|
||||
from ..a6.main import evaluate as evaluate_v2
|
||||
|
||||
|
||||
def step(
|
||||
@@ -65,44 +65,11 @@ def get_random_sampling_mask(labels: torch.Tensor, neg_ratio: float) -> torch.Te
|
||||
return mask
|
||||
|
||||
|
||||
def get_detection_metrics(
|
||||
output: Tensor, labels: torch.Tensor, threshold: float
|
||||
) -> tuple[float, float, float, float]:
|
||||
"""
|
||||
Returns precision, recall, f1 for the positive (human) class, and overall accuracy.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
probs = torch.softmax(output, dim=-1)[..., 1]
|
||||
|
||||
preds = probs >= threshold
|
||||
|
||||
TP = ((preds == 1) & (labels == 1)).sum().item()
|
||||
FP = ((preds == 1) & (labels == 0)).sum().item()
|
||||
FN = ((preds == 0) & (labels == 1)).sum().item()
|
||||
TN = ((preds == 0) & (labels == 0)).sum().item()
|
||||
|
||||
precision = TP / (TP + FP) if (TP + FP) > 0 else 0.0
|
||||
recall = TP / (TP + FN) if (TP + FN) > 0 else 0.0
|
||||
f1 = (
|
||||
2 * precision * recall / (precision + recall)
|
||||
if (precision + recall) > 0
|
||||
else 0.0
|
||||
)
|
||||
accuracy = (TP + TN) / (TP + TN + FP + FN) if (TP + TN + FP + FN) > 0 else 0.0
|
||||
|
||||
return (
|
||||
precision,
|
||||
recall,
|
||||
f1,
|
||||
accuracy,
|
||||
)
|
||||
|
||||
|
||||
def evaluate(
|
||||
model: MmpNet,
|
||||
criterion,
|
||||
dataloader: DataLoader,
|
||||
) -> tuple[float, float, float, float]:
|
||||
) -> float:
|
||||
device = next(model.parameters()).device
|
||||
model.eval()
|
||||
total_loss = 0.0
|
||||
@@ -123,15 +90,7 @@ def evaluate(
|
||||
all_outputs.append(outputs.cpu())
|
||||
all_labels.append(lbl_batch.cpu())
|
||||
avg_loss = total_loss / total_samples if total_samples > 0 else 0.0
|
||||
if all_outputs and all_labels:
|
||||
outputs_cat = torch.cat(all_outputs)
|
||||
labels_cat = torch.cat(all_labels)
|
||||
precision, recall, f1, acc = get_detection_metrics(
|
||||
outputs_cat, labels_cat, threshold=0.5
|
||||
)
|
||||
else:
|
||||
precision = recall = f1 = 0.0
|
||||
return avg_loss, precision, recall, f1, acc
|
||||
return avg_loss
|
||||
|
||||
|
||||
def train(
|
||||
@@ -243,7 +202,7 @@ def main():
|
||||
_, optimizer = get_criterion_optimizer(model=model)
|
||||
criterion = NegativeMiningCriterion(enable_negative_mining=True)
|
||||
criterion_eval = NegativeMiningCriterion(enable_negative_mining=False)
|
||||
num_epochs = 10
|
||||
num_epochs = 5
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
train_loss = train(
|
||||
@@ -252,17 +211,16 @@ def main():
|
||||
criterion=criterion,
|
||||
optimizer=optimizer,
|
||||
)
|
||||
avg_loss, precision, recall, f1, acc = evaluate(
|
||||
avg_loss = evaluate(
|
||||
model=model, criterion=criterion_eval, dataloader=dataloader_val
|
||||
)
|
||||
_ = evaluate_v2(
|
||||
model=model, device=device, anchor_grid=anchor_grid, loader=dataloader_train
|
||||
)
|
||||
|
||||
if writer is not None:
|
||||
writer.add_scalar("Loss/train_epoch", train_loss, epoch)
|
||||
writer.add_scalar("Loss/eval_epoch", avg_loss, epoch)
|
||||
writer.add_scalar("Acc/precision", precision, epoch)
|
||||
writer.add_scalar("Acc/recall", recall, epoch)
|
||||
writer.add_scalar("Acc/acc", acc, epoch)
|
||||
writer.add_scalar("Acc/f1", f1, epoch)
|
||||
|
||||
if writer is not None:
|
||||
writer.close()
|
||||
|
||||
Reference in New Issue
Block a user