adapts metrics
This commit is contained in:
@@ -67,9 +67,9 @@ def get_random_sampling_mask(labels: torch.Tensor, neg_ratio: float) -> torch.Te
|
|||||||
|
|
||||||
def get_detection_metrics(
|
def get_detection_metrics(
|
||||||
output: Tensor, labels: torch.Tensor, threshold: float
|
output: Tensor, labels: torch.Tensor, threshold: float
|
||||||
) -> tuple[float, float, float]:
|
) -> tuple[float, float, float, float]:
|
||||||
"""
|
"""
|
||||||
Returns precision, recall, f1 for the positive (human) class.
|
Returns precision, recall, f1 for the positive (human) class, and overall accuracy.
|
||||||
"""
|
"""
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
probs = torch.softmax(output, dim=-1)[..., 1]
|
probs = torch.softmax(output, dim=-1)[..., 1]
|
||||||
@@ -82,16 +82,20 @@ def get_detection_metrics(
|
|||||||
TN = ((preds == 0) & (labels == 0)).sum().item()
|
TN = ((preds == 0) & (labels == 0)).sum().item()
|
||||||
|
|
||||||
precision = TP / (TP + FP) if (TP + FP) > 0 else 0.0
|
precision = TP / (TP + FP) if (TP + FP) > 0 else 0.0
|
||||||
neg_precision = TN / (TN + FN) if (TN + FN) > 0 else 0.0
|
|
||||||
recall = TP / (TP + FN) if (TP + FN) > 0 else 0.0
|
recall = TP / (TP + FN) if (TP + FN) > 0 else 0.0
|
||||||
neg_recall = TN / (TN + FP) if (TN + FP) > 0 else 0.0
|
|
||||||
f1 = (
|
f1 = (
|
||||||
2 * precision * recall / (precision + recall)
|
2 * precision * recall / (precision + recall)
|
||||||
if (precision + recall) > 0
|
if (precision + recall) > 0
|
||||||
else 0.0
|
else 0.0
|
||||||
)
|
)
|
||||||
|
accuracy = (TP + TN) / (TP + TN + FP + FN) if (TP + TN + FP + FN) > 0 else 0.0
|
||||||
|
|
||||||
return precision, recall, f1, neg_precision, neg_recall
|
return (
|
||||||
|
precision,
|
||||||
|
recall,
|
||||||
|
f1,
|
||||||
|
accuracy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
@@ -122,12 +126,12 @@ def evaluate(
|
|||||||
if all_outputs and all_labels:
|
if all_outputs and all_labels:
|
||||||
outputs_cat = torch.cat(all_outputs)
|
outputs_cat = torch.cat(all_outputs)
|
||||||
labels_cat = torch.cat(all_labels)
|
labels_cat = torch.cat(all_labels)
|
||||||
precision, recall, f1, neg_precision, neg_recall = get_detection_metrics(
|
precision, recall, f1, acc = get_detection_metrics(
|
||||||
outputs_cat, labels_cat, threshold=0.5
|
outputs_cat, labels_cat, threshold=0.5
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
precision = recall = f1 = 0.0
|
precision = recall = f1 = 0.0
|
||||||
return avg_loss, precision, recall, f1, neg_precision, neg_recall
|
return avg_loss, precision, recall, f1, acc
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
@@ -239,7 +243,7 @@ def main():
|
|||||||
_, optimizer = get_criterion_optimizer(model=model)
|
_, optimizer = get_criterion_optimizer(model=model)
|
||||||
criterion = NegativeMiningCriterion(enable_negative_mining=True)
|
criterion = NegativeMiningCriterion(enable_negative_mining=True)
|
||||||
criterion_eval = NegativeMiningCriterion(enable_negative_mining=False)
|
criterion_eval = NegativeMiningCriterion(enable_negative_mining=False)
|
||||||
num_epochs = 7
|
num_epochs = 10
|
||||||
|
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
train_loss = train(
|
train_loss = train(
|
||||||
@@ -248,7 +252,7 @@ def main():
|
|||||||
criterion=criterion,
|
criterion=criterion,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
)
|
)
|
||||||
avg_loss, precision, recall, f1, neg_precision, neg_recall = evaluate(
|
avg_loss, precision, recall, f1, acc = evaluate(
|
||||||
model=model, criterion=criterion_eval, dataloader=dataloader_val
|
model=model, criterion=criterion_eval, dataloader=dataloader_val
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -257,8 +261,7 @@ def main():
|
|||||||
writer.add_scalar("Loss/eval_epoch", avg_loss, epoch)
|
writer.add_scalar("Loss/eval_epoch", avg_loss, epoch)
|
||||||
writer.add_scalar("Acc/precision", precision, epoch)
|
writer.add_scalar("Acc/precision", precision, epoch)
|
||||||
writer.add_scalar("Acc/recall", recall, epoch)
|
writer.add_scalar("Acc/recall", recall, epoch)
|
||||||
writer.add_scalar("Acc/neg_precision", neg_precision, epoch)
|
writer.add_scalar("Acc/acc", acc, epoch)
|
||||||
writer.add_scalar("Acc/neg_recall", neg_recall, epoch)
|
|
||||||
writer.add_scalar("Acc/f1", f1, epoch)
|
writer.add_scalar("Acc/f1", f1, epoch)
|
||||||
|
|
||||||
if writer is not None:
|
if writer is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user