adds nms and eval

franksim
2025-12-02 11:04:47 +01:00
parent 3b6a588719
commit a6f70005f2
9 changed files with 428 additions and 985 deletions

View File

@@ -5,7 +5,7 @@ import numpy as np
 import torch
 from torch.utils.data import DataLoader
 from ..a3.annotation import read_groundtruth_file, AnnotationRect
-from .label_grid import get_label_grid, iou
+from .label_grid import get_label_grid
 import matplotlib.pyplot as plt
 import matplotlib.patches as patches
 from .anchor_grid import get_anchor_grid
@@ -178,7 +178,6 @@ def compute_ious_vectorized(boxes1, boxes2):
     boxes1: (M, 4), boxes2: (N, 4) -- format [x1, y1, x2, y2]
     Returns: (M, N) IoU
     """
-    M, N = boxes1.shape[0], boxes2.shape[0]
     # Expand to (M, N, 4)
     boxes1 = boxes1[:, None, :]  # (M, 1, 4)
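
For context, the broadcasting pattern this hunk relies on can be sketched in a few lines of NumPy (a minimal sketch assuming plain [x1, y1, x2, y2] extents; the repository's compute_ious_vectorized may differ in details such as inclusive pixel conventions):

import numpy as np

def iou_matrix(boxes1: np.ndarray, boxes2: np.ndarray) -> np.ndarray:
    # boxes1: (M, 4), boxes2: (N, 4); returns (M, N) pairwise IoU
    b1 = boxes1[:, None, :]  # (M, 1, 4)
    b2 = boxes2[None, :, :]  # (1, N, 4)
    # Intersection rectangle for every pair via broadcasting
    ix1 = np.maximum(b1[..., 0], b2[..., 0])
    iy1 = np.maximum(b1[..., 1], b2[..., 1])
    ix2 = np.minimum(b1[..., 2], b2[..., 2])
    iy2 = np.minimum(b1[..., 3], b2[..., 3])
    inter = np.clip(ix2 - ix1, 0, None) * np.clip(iy2 - iy1, 0, None)
    area1 = (b1[..., 2] - b1[..., 0]) * (b1[..., 3] - b1[..., 1])
    area2 = (b2[..., 2] - b2[..., 0]) * (b2[..., 3] - b2[..., 1])
    return inter / np.maximum(area1 + area2 - inter, 1e-9)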

View File

@@ -3,7 +3,6 @@ import torch
 import torch.optim as optim
 import torch.nn as nn
 from torch.utils.data import DataLoader
-from torch import Tensor
 from tqdm import tqdm
 import datetime
@@ -11,6 +10,7 @@ from .model import MmpNet
 from ..a4.anchor_grid import get_anchor_grid
 from ..a4.dataset import get_dataloader
 from ..a2.main import get_criterion_optimizer
+from ..a6.main import evaluate as evaluate_v2
 def step(
@@ -65,44 +65,11 @@ def get_random_sampling_mask(labels: torch.Tensor, neg_ratio: float) -> torch.Tensor:
     return mask
-def get_detection_metrics(
-    output: Tensor, labels: torch.Tensor, threshold: float
-) -> tuple[float, float, float, float]:
-    """
-    Returns precision, recall, f1 for the positive (human) class, and overall accuracy.
-    """
-    with torch.no_grad():
-        probs = torch.softmax(output, dim=-1)[..., 1]
-        preds = probs >= threshold
-        TP = ((preds == 1) & (labels == 1)).sum().item()
-        FP = ((preds == 1) & (labels == 0)).sum().item()
-        FN = ((preds == 0) & (labels == 1)).sum().item()
-        TN = ((preds == 0) & (labels == 0)).sum().item()
-        precision = TP / (TP + FP) if (TP + FP) > 0 else 0.0
-        recall = TP / (TP + FN) if (TP + FN) > 0 else 0.0
-        f1 = (
-            2 * precision * recall / (precision + recall)
-            if (precision + recall) > 0
-            else 0.0
-        )
-        accuracy = (TP + TN) / (TP + TN + FP + FN) if (TP + TN + FP + FN) > 0 else 0.0
-        return (
-            precision,
-            recall,
-            f1,
-            accuracy,
-        )
 def evaluate(
     model: MmpNet,
     criterion,
     dataloader: DataLoader,
-) -> tuple[float, float, float, float]:
+) -> float:
     device = next(model.parameters()).device
     model.eval()
     total_loss = 0.0
@@ -123,15 +90,7 @@ def evaluate(
         all_outputs.append(outputs.cpu())
         all_labels.append(lbl_batch.cpu())
     avg_loss = total_loss / total_samples if total_samples > 0 else 0.0
-    if all_outputs and all_labels:
-        outputs_cat = torch.cat(all_outputs)
-        labels_cat = torch.cat(all_labels)
-        precision, recall, f1, acc = get_detection_metrics(
-            outputs_cat, labels_cat, threshold=0.5
-        )
-    else:
-        precision = recall = f1 = 0.0
-    return avg_loss, precision, recall, f1, acc
+    return avg_loss
 def train(
@@ -243,7 +202,7 @@ def main():
     _, optimizer = get_criterion_optimizer(model=model)
     criterion = NegativeMiningCriterion(enable_negative_mining=True)
     criterion_eval = NegativeMiningCriterion(enable_negative_mining=False)
-    num_epochs = 10
+    num_epochs = 5
     for epoch in range(num_epochs):
         train_loss = train(
@@ -252,17 +211,16 @@ def main():
             criterion=criterion,
             optimizer=optimizer,
         )
-        avg_loss, precision, recall, f1, acc = evaluate(
+        avg_loss = evaluate(
             model=model, criterion=criterion_eval, dataloader=dataloader_val
         )
+        _ = evaluate_v2(
+            model=model, device=device, anchor_grid=anchor_grid, loader=dataloader_train
+        )
         if writer is not None:
             writer.add_scalar("Loss/train_epoch", train_loss, epoch)
             writer.add_scalar("Loss/eval_epoch", avg_loss, epoch)
-            writer.add_scalar("Acc/precision", precision, epoch)
-            writer.add_scalar("Acc/recall", recall, epoch)
-            writer.add_scalar("Acc/acc", acc, epoch)
-            writer.add_scalar("Acc/f1", f1, epoch)
     if writer is not None:
         writer.close()

mmp/a6/__init__.py Normal file
View File

View File

@@ -1,25 +1,144 @@
 from typing import List, Tuple
 import torch
 import numpy as np
+from tqdm import tqdm
+import os
+from torch.utils.data import DataLoader
+from mmp.a6.evallib import calculate_ap_pr
+from ..a4.label_grid import iou
 from ..a5.model import MmpNet
-from ..a3.annotation import AnnotationRect
+from ..a3.annotation import AnnotationRect, read_groundtruth_file
+from .nms import non_maximum_suppression
 def batch_inference(
     model: MmpNet, images: torch.Tensor, device: torch.device, anchor_grid: np.ndarray
 ) -> List[List[Tuple[AnnotationRect, float]]]:
-    raise NotImplementedError()
+    score_thresh = 0.5
+    nms_thresh = 0.3
+    model = model.to(device)
+    model.eval()
+    images = images.to(device)
+    # anchor_grid shape: [W, R, h, w, 4]
+    results = []
+    with torch.no_grad():
+        outputs = model(images)  # (B, W, R, h, w, 2)
+        probs = torch.softmax(outputs, dim=-1)[..., 1]  # (B, W, R, h, w)
+    probs_np = probs.cpu().numpy()
+    batch_size = outputs.shape[0]
+    for b in range(batch_size):
+        detections = []
+        for idx in np.ndindex(anchor_grid.shape[:-1]):
+            score = probs_np[b][idx]
+            # if score >= score_thresh:  (score_thresh is unused while this filter stays commented out)
+            box = anchor_grid[idx]
+            rect = AnnotationRect.fromarray(box)
+            detections.append((rect, float(score)))
+        detections_nms = non_maximum_suppression(detections, nms_thresh)
+        results.append(detections_nms)
+    return results
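
A minimal usage sketch (assuming the loader yields (images, labels, ids) batches, as evaluate below consumes them):

images, _, ids = next(iter(loader))
dets = batch_inference(model, images, device, anchor_grid)  # one (rect, score) list per image
best_rect, best_score = max(dets[0], key=lambda rs: rs[1])  # top-scoring box of the first image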
-def evaluate() -> float:  # feel free to change the arguments
+def evaluate(
+    model: MmpNet, loader: DataLoader, device: torch.device, anchor_grid: np.ndarray
+) -> float:
     """Evaluates a specified model on the whole validation dataset.
     @return: AP for the validation set as a float.
     You decide which arguments this function should receive
     """
-    raise NotImplementedError()
+    path_to_data = ".data/mmp-public-3.2/train"
+    progress_bar = tqdm(loader, desc="Evaluation", unit="batch")
+    image_count = 0
+    ap_total = 0.0
+    for img_batch, _, id_batch in progress_bar:
+        inference = batch_inference(
+            anchor_grid=anchor_grid, device=device, images=img_batch, model=model
+        )
+        gts = get_gts_for_batch(id_batch=id_batch, gt_base_path=path_to_data)
+        dict_detections = {
+            img_id.item(): inference[idx] for idx, img_id in enumerate(id_batch)
+        }
+        dict_gt = {img_id.item(): gts[idx] for idx, img_id in enumerate(id_batch)}
+        average_precision, precision, recall = calculate_ap_pr(dict_detections, dict_gt)
+        # Running weighted mean: weight the batch AP by the batch size
+        ap_total = (ap_total * image_count + average_precision * id_batch.shape[0]) / (
+            image_count + id_batch.shape[0]
+        )
+        image_count += id_batch.shape[0]
+        progress_bar.set_postfix({"ap": ap_total})
+    return ap_total
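
This is the evaluate_v2 entry point called from a5/main.py above; a standalone call looks like:

ap = evaluate(model=model, loader=dataloader_train, device=device, anchor_grid=anchor_grid)
print(f"AP: {ap:.3f}")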
+def get_gts_for_batch(
+    id_batch: torch.Tensor, gt_base_path: str
+) -> List[List[AnnotationRect]]:
+    return [
+        read_groundtruth_file(
+            os.path.join(gt_base_path, f"{str(img_id.item()).zfill(8)}.gt_data.txt")
+        )
+        for img_id in id_batch
+    ]
+def calc_tp_fp_fn(
+    detections: List[Tuple[AnnotationRect, float]],
+    gts: List[AnnotationRect],
+    iou_threshold: float = 0.5,
+    confidence_threshold: float = 0.5,
+) -> tuple[int, int, int]:
+    """
+    Counts detection matches against ground truth for a single image.
+    Args:
+        detections: List of (AnnotationRect, confidence) tuples of predicted boxes and scores.
+        gts: List of AnnotationRect ground-truth boxes.
+        iou_threshold: Minimum IoU to count a detection as a true positive.
+        confidence_threshold: Minimum confidence required to include a detection.
+    Returns:
+        num_tp, num_fp, num_fn: Numbers of true positives, false positives, and false negatives.
+    """
+    detections = [det for det in detections if det[1] >= confidence_threshold]
+    detections.sort(key=lambda x: x[1], reverse=True)
+    matches = set()
+    fp = 0
+    tp = 0
+    for det_rect, _ in detections:
+        iou_map = [iou(det_rect, gt_rect) for gt_rect in gts]
+        if len(iou_map) == 0:
+            fp += 1
+            continue
+        max_idx = np.argmax(iou_map)
+        # Greedy matching: each ground-truth box may match at most one detection
+        if max_idx in matches or iou_map[max_idx] < iou_threshold:
+            fp += 1
+            continue
+        matches.add(max_idx)
+        tp += 1
+    fn = len(gts) - len(matches)
+    return tp, fp, fn
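
A toy check of the greedy matching (hypothetical boxes; exact IoU values depend on AnnotationRect's pixel conventions, but every comparison here is far from the 0.5 threshold):

gts = [AnnotationRect(0, 0, 10, 10), AnnotationRect(20, 20, 30, 30)]
dets = [
    (AnnotationRect(1, 1, 10, 10), 0.9),    # high IoU with the first GT -> TP
    (AnnotationRect(0, 0, 9, 9), 0.8),      # same GT, already matched -> FP
    (AnnotationRect(50, 50, 60, 60), 0.7),  # overlaps nothing -> FP
]
tp, fp, fn = calc_tp_fp_fn(dets, gts)  # expected: (1, 2, 1); the second GT is missed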
 def evaluate_test():  # feel free to change the arguments

File diff suppressed because it is too large

View File

@@ -1,6 +1,9 @@
+import os
 from typing import List, Sequence, Tuple
 from ..a3.annotation import AnnotationRect
+from ..a4.label_grid import iou, draw_annotation_rects
+from collections import defaultdict
 def non_maximum_suppression(
@@ -12,4 +15,68 @@ def non_maximum_suppression(
     @return: A list of tuples of the remaining boxes after NMS together with their scores
     """
-    raise NotImplementedError()
+    if not boxes_scores:
+        return []
+    # Sort the boxes by score in descending order
+    boxes_scores_sorted = sorted(boxes_scores, key=lambda bs: bs[1], reverse=True)
+    result = []
+    while boxes_scores_sorted:
+        # Select the box with the highest score and remove it from the list
+        curr_box, curr_score = boxes_scores_sorted.pop(0)
+        result.append((curr_box, curr_score))
+        # Drop remaining boxes whose IoU with the kept box exceeds the threshold
+        new_boxes = []
+        for box, score in boxes_scores_sorted:
+            if iou(curr_box, box) <= threshold:
+                new_boxes.append((box, score))
+        boxes_scores_sorted = new_boxes
+    return result
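
A quick sanity check (hypothetical boxes): the lower-scored box that heavily overlaps the winner is suppressed, while the disjoint box survives:

a = AnnotationRect(0, 0, 10, 10)
b = AnnotationRect(1, 1, 11, 11)    # IoU with a is roughly 0.7, above the threshold
c = AnnotationRect(40, 40, 50, 50)  # no overlap with a
kept = non_maximum_suppression([(a, 0.9), (b, 0.8), (c, 0.7)], 0.3)
# kept == [(a, 0.9), (c, 0.7)]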
+def read_boxes_from_file(filepath: str) -> List[Tuple[str, AnnotationRect, float]]:
+    """
+    Reads a file containing bounding boxes and scores in the format:
+    {image_number} {x1} {y1} {x2} {y2} {score}
+    Returns a list of tuples: (image_number, AnnotationRect, score)
+    """
+    boxes: List[Tuple[str, AnnotationRect, float]] = []
+    with open(filepath, "r") as f:
+        for line in f:
+            parts = line.strip().split()
+            if len(parts) != 6:
+                continue
+            img_id = parts[0]
+            x1, y1, x2, y2 = map(int, parts[1:5])
+            annotation_rect = AnnotationRect(x1, y1, x2, y2)
+            score = float(parts[5])
+            boxes.append((img_id, annotation_rect, score))
+    return boxes
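
A hypothetical model_output.txt line and its parsed form:

# file line: 00000042 12 34 120 200 0.87
rows = read_boxes_from_file("mmp/a6/model_output.txt")
img_id, rect, score = rows[0]  # ("00000042", AnnotationRect(12, 34, 120, 200), 0.87)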
+def main():
+    boxes = read_boxes_from_file("mmp/a6/model_output.txt")
+    grouped = defaultdict(list)
+    for image_id, rect, score in boxes:
+        grouped[image_id].append((rect, score))
+    for image_id, rects_scores in grouped.items():
+        filtered_boxes = non_maximum_suppression(rects_scores, 0.3)
+        annotation_rects = [rect for rect, score in filtered_boxes if score > 0.5]
+        input_path = f".data/mmp-public-3.2/test/{image_id}.jpg"
+        output_path = f"mmp/a6/nms_output_{image_id}.png"
+        if not os.path.exists(input_path):
+            continue
+        draw_annotation_rects(
+            input_path,
+            annotation_rects,
+            rect_color=(255, 0, 0),
+            rect_width=2,
+            output_path=output_path,
+        )
+
+if __name__ == "__main__":
+    main()

Binary file not shown (new file, 359 KiB).

Binary file not shown (new file, 231 KiB).

Binary file not shown (new file, 232 KiB).