adds nms and eval
This commit is contained in:
@@ -5,7 +5,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from ..a3.annotation import read_groundtruth_file, AnnotationRect
|
from ..a3.annotation import read_groundtruth_file, AnnotationRect
|
||||||
from .label_grid import get_label_grid, iou
|
from .label_grid import get_label_grid
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import matplotlib.patches as patches
|
import matplotlib.patches as patches
|
||||||
from .anchor_grid import get_anchor_grid
|
from .anchor_grid import get_anchor_grid
|
||||||
@@ -178,7 +178,6 @@ def compute_ious_vectorized(boxes1, boxes2):
|
|||||||
boxes1: (M, 4), boxes2: (N, 4) -- format [x1, y1, x2, y2]
|
boxes1: (M, 4), boxes2: (N, 4) -- format [x1, y1, x2, y2]
|
||||||
Returns: (M, N) IoU
|
Returns: (M, N) IoU
|
||||||
"""
|
"""
|
||||||
M, N = boxes1.shape[0], boxes2.shape[0]
|
|
||||||
|
|
||||||
# Expand to (M, N, 4)
|
# Expand to (M, N, 4)
|
||||||
boxes1 = boxes1[:, None, :] # (M, 1, 4)
|
boxes1 = boxes1[:, None, :] # (M, 1, 4)
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import torch
|
|||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch import Tensor
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
@@ -11,6 +10,7 @@ from .model import MmpNet
|
|||||||
from ..a4.anchor_grid import get_anchor_grid
|
from ..a4.anchor_grid import get_anchor_grid
|
||||||
from ..a4.dataset import get_dataloader
|
from ..a4.dataset import get_dataloader
|
||||||
from ..a2.main import get_criterion_optimizer
|
from ..a2.main import get_criterion_optimizer
|
||||||
|
from ..a6.main import evaluate as evaluate_v2
|
||||||
|
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
@@ -65,44 +65,11 @@ def get_random_sampling_mask(labels: torch.Tensor, neg_ratio: float) -> torch.Te
|
|||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
def get_detection_metrics(
    output: Tensor, labels: torch.Tensor, threshold: float
) -> tuple[float, float, float, float]:
    """Compute precision, recall and F1 of the positive class plus accuracy.

    The positive-class probability is the softmax over the last axis of
    ``output`` taken at index 1. A cell counts as predicted positive when that
    probability is at least ``threshold``. Only cells whose label is exactly
    0 or 1 contribute to the confusion counts.

    Returns:
        ``(precision, recall, f1, accuracy)``; every ratio falls back to 0.0
        when its denominator is zero.
    """
    with torch.no_grad():
        positive_prob = torch.softmax(output, dim=-1)[..., 1]
        predicted_pos = positive_prob >= threshold
        predicted_neg = ~predicted_pos

        actual_pos = labels == 1
        actual_neg = labels == 0

        true_pos = (predicted_pos & actual_pos).sum().item()
        false_pos = (predicted_pos & actual_neg).sum().item()
        false_neg = (predicted_neg & actual_pos).sum().item()
        true_neg = (predicted_neg & actual_neg).sum().item()

    predicted_total = true_pos + false_pos
    actual_total = true_pos + false_neg
    counted_total = true_pos + true_neg + false_pos + false_neg

    precision = 0.0
    if predicted_total > 0:
        precision = true_pos / predicted_total

    recall = 0.0
    if actual_total > 0:
        recall = true_pos / actual_total

    f1 = 0.0
    if (precision + recall) > 0:
        f1 = 2 * precision * recall / (precision + recall)

    accuracy = 0.0
    if counted_total > 0:
        accuracy = (true_pos + true_neg) / counted_total

    return precision, recall, f1, accuracy
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
model: MmpNet,
|
model: MmpNet,
|
||||||
criterion,
|
criterion,
|
||||||
dataloader: DataLoader,
|
dataloader: DataLoader,
|
||||||
) -> tuple[float, float, float, float]:
|
) -> float:
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
model.eval()
|
model.eval()
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
@@ -123,15 +90,7 @@ def evaluate(
|
|||||||
all_outputs.append(outputs.cpu())
|
all_outputs.append(outputs.cpu())
|
||||||
all_labels.append(lbl_batch.cpu())
|
all_labels.append(lbl_batch.cpu())
|
||||||
avg_loss = total_loss / total_samples if total_samples > 0 else 0.0
|
avg_loss = total_loss / total_samples if total_samples > 0 else 0.0
|
||||||
if all_outputs and all_labels:
|
return avg_loss
|
||||||
outputs_cat = torch.cat(all_outputs)
|
|
||||||
labels_cat = torch.cat(all_labels)
|
|
||||||
precision, recall, f1, acc = get_detection_metrics(
|
|
||||||
outputs_cat, labels_cat, threshold=0.5
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
precision = recall = f1 = 0.0
|
|
||||||
return avg_loss, precision, recall, f1, acc
|
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
@@ -243,7 +202,7 @@ def main():
|
|||||||
_, optimizer = get_criterion_optimizer(model=model)
|
_, optimizer = get_criterion_optimizer(model=model)
|
||||||
criterion = NegativeMiningCriterion(enable_negative_mining=True)
|
criterion = NegativeMiningCriterion(enable_negative_mining=True)
|
||||||
criterion_eval = NegativeMiningCriterion(enable_negative_mining=False)
|
criterion_eval = NegativeMiningCriterion(enable_negative_mining=False)
|
||||||
num_epochs = 10
|
num_epochs = 5
|
||||||
|
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
train_loss = train(
|
train_loss = train(
|
||||||
@@ -252,17 +211,16 @@ def main():
|
|||||||
criterion=criterion,
|
criterion=criterion,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
)
|
)
|
||||||
avg_loss, precision, recall, f1, acc = evaluate(
|
avg_loss = evaluate(
|
||||||
model=model, criterion=criterion_eval, dataloader=dataloader_val
|
model=model, criterion=criterion_eval, dataloader=dataloader_val
|
||||||
)
|
)
|
||||||
|
_ = evaluate_v2(
|
||||||
|
model=model, device=device, anchor_grid=anchor_grid, loader=dataloader_train
|
||||||
|
)
|
||||||
|
|
||||||
if writer is not None:
|
if writer is not None:
|
||||||
writer.add_scalar("Loss/train_epoch", train_loss, epoch)
|
writer.add_scalar("Loss/train_epoch", train_loss, epoch)
|
||||||
writer.add_scalar("Loss/eval_epoch", avg_loss, epoch)
|
writer.add_scalar("Loss/eval_epoch", avg_loss, epoch)
|
||||||
writer.add_scalar("Acc/precision", precision, epoch)
|
|
||||||
writer.add_scalar("Acc/recall", recall, epoch)
|
|
||||||
writer.add_scalar("Acc/acc", acc, epoch)
|
|
||||||
writer.add_scalar("Acc/f1", f1, epoch)
|
|
||||||
|
|
||||||
if writer is not None:
|
if writer is not None:
|
||||||
writer.close()
|
writer.close()
|
||||||
|
|||||||
0
mmp/a6/__init__.py
Normal file
0
mmp/a6/__init__.py
Normal file
127
mmp/a6/main.py
127
mmp/a6/main.py
@@ -1,25 +1,144 @@
|
|||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
import os
|
||||||
|
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from mmp.a6.evallib import calculate_ap_pr
|
||||||
|
from ..a4.label_grid import iou
|
||||||
|
|
||||||
from ..a5.model import MmpNet
|
from ..a5.model import MmpNet
|
||||||
from ..a3.annotation import AnnotationRect
|
from ..a3.annotation import AnnotationRect, read_groundtruth_file
|
||||||
|
|
||||||
|
from .nms import non_maximum_suppression
|
||||||
|
|
||||||
|
|
||||||
def batch_inference(
    model: MmpNet, images: torch.Tensor, device: torch.device, anchor_grid: np.ndarray
) -> List[List[Tuple[AnnotationRect, float]]]:
    """Run the model on one image batch and return NMS-filtered detections.

    The model is moved to ``device``, switched to eval mode and applied to
    ``images`` without gradient tracking. The positive-class score is the
    softmax over the last output axis at index 1. For every image, each anchor
    box is paired with its score (no confidence filtering is applied here) and
    the full candidate list is handed to ``non_maximum_suppression`` with an
    IoU threshold of 0.3.

    Returns:
        One list per image of ``(AnnotationRect, score)`` tuples that
        survived NMS.
    """
    nms_thresh = 0.3

    model = model.to(device)
    model.eval()

    with torch.no_grad():
        # NOTE(review): the original comments give the output shape as
        # (B, W, R, h, w, 2) and the anchor grid as (W, R, h, w, 4) - confirm.
        logits = model(images.to(device))
        scores_np = torch.softmax(logits, dim=-1)[..., 1].cpu().numpy()

    # Every index of the anchor grid except its trailing coordinate axis.
    anchor_indices = list(np.ndindex(anchor_grid.shape[:-1]))

    results = []
    for image_idx in range(logits.shape[0]):
        image_scores = scores_np[image_idx]
        candidates = []
        for anchor_idx in anchor_indices:
            confidence = image_scores[anchor_idx]
            rect = AnnotationRect.fromarray(anchor_grid[anchor_idx])
            candidates.append((rect, float(confidence)))
        results.append(non_maximum_suppression(candidates, nms_thresh))

    return results
|
||||||
|
|
||||||
|
|
||||||
def evaluate(
    model: MmpNet,
    loader: DataLoader,
    device: torch.device,
    anchor_grid: np.ndarray,
    path_to_data: str = ".data/mmp-public-3.2/train",
) -> float:
    """Evaluates a specified model on every image provided by ``loader``.

    Detections and ground truths of all batches are collected first and the
    average precision is computed once over the whole set. Averaging
    per-batch AP values is not equivalent to the dataset AP, and the previous
    running mean additionally weighted each batch AP by 1 instead of by the
    batch size.

    @param model: Network passed on to ``batch_inference``.
    @param loader: Yields ``(images, _, image_ids)`` batches.
    @param device: Device used for inference.
    @param anchor_grid: Anchor boxes passed on to ``batch_inference``.
    @param path_to_data: Directory holding the ``*.gt_data.txt`` files of the
        images served by ``loader``. The default points at the train split,
        as before; pass the matching directory when evaluating another split.
    @return: AP over all images of ``loader`` as a float (0.0 if the loader
        yields no images).
    """
    detections_by_image = {}
    gts_by_image = {}

    progress_bar = tqdm(loader, desc="Evaluation", unit="batch")
    for img_batch, _, id_batch in progress_bar:
        inference = batch_inference(
            anchor_grid=anchor_grid, device=device, images=img_batch, model=model
        )
        gts = get_gts_for_batch(id_batch=id_batch, gt_base_path=path_to_data)

        for idx, img_id in enumerate(id_batch):
            key = img_id.item()
            detections_by_image[key] = inference[idx]
            gts_by_image[key] = gts[idx]

        progress_bar.set_postfix({"images": len(detections_by_image)})

    if not detections_by_image:
        # Nothing was evaluated; keep the previous "no data" result of zero.
        return 0.0

    # calculate_ap_pr is unpacked as (ap, precision, recall), as before.
    average_precision, _, _ = calculate_ap_pr(detections_by_image, gts_by_image)
    return float(average_precision)
|
||||||
|
|
||||||
|
|
||||||
|
def get_gts_for_batch(
    id_batch: torch.Tensor, gt_base_path: str
) -> List[List[AnnotationRect]]:
    """Load the ground-truth annotations for every image id in a batch.

    Each id is converted to a Python number, rendered as a string and
    zero-padded to eight characters to form ``<id>.gt_data.txt`` inside
    ``gt_base_path``. The file is parsed with ``read_groundtruth_file``.

    Returns:
        One list of annotations per id, in the order of ``id_batch``.
    """
    annotations_per_image = []
    for img_id in id_batch:
        file_name = f"{str(img_id.item()).zfill(8)}.gt_data.txt"
        gt_path = os.path.join(gt_base_path, file_name)
        annotations_per_image.append(read_groundtruth_file(gt_path))
    return annotations_per_image
|
||||||
|
|
||||||
|
|
||||||
|
def calc_tp_fp_fn(
    detections: List[Tuple[AnnotationRect, float]],
    gts: List[AnnotationRect],
    iou_threshold: float = 0.5,
    confidence_threshhold: float = 0.5,
) -> tuple[int, int, int]:
    """
    Counts true positives, false positives and false negatives on a single image.

    Detections below the confidence threshold are dropped and the remaining
    ones are processed in descending score order (the input list itself is
    neither required to be sorted nor modified). Each detection is compared
    with the ground-truth box it overlaps most. It is a true positive if that
    box has not been matched yet and the IoU reaches ``iou_threshold``;
    otherwise it is a false positive. Unmatched ground-truth boxes are false
    negatives.

    Args:
        detections: List of (AnnotationRect, confidence) tuples representing predicted boxes and scores.
        gts: List of AnnotationRect for ground truth.
        iou_threshold: Minimum IoU to consider a detection a true positive.
        confidence_threshhold: Minimum confidence required to include a detection.
    Returns:
        num_tp: Number of true positives (int).
        num_fp: Number of false positives (int).
        num_fn: Number of false negatives (int).
    """
    kept = sorted(
        (det for det in detections if det[1] >= confidence_threshhold),
        key=lambda det: det[1],
        reverse=True,
    )

    if not gts:
        # Without ground truth every surviving detection is a false positive.
        return 0, len(kept), 0

    matched = set()
    tp = 0
    fp = 0
    for det_rect, _ in kept:
        overlaps = [iou(det_rect, gt_rect) for gt_rect in gts]
        best_idx = int(np.argmax(overlaps))
        # A second hit on an already matched box counts as a false positive.
        if best_idx in matched or overlaps[best_idx] < iou_threshold:
            fp += 1
            continue
        matched.add(best_idx)
        tp += 1

    fn = len(gts) - len(matched)
    return tp, fp, fn
|
||||||
|
|
||||||
|
|
||||||
def evaluate_test(): # feel free to change the arguments
|
def evaluate_test(): # feel free to change the arguments
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,9 @@
|
|||||||
|
import os
|
||||||
from typing import List, Sequence, Tuple
|
from typing import List, Sequence, Tuple
|
||||||
|
|
||||||
from ..a3.annotation import AnnotationRect
|
from ..a3.annotation import AnnotationRect
|
||||||
|
from ..a4.label_grid import iou, draw_annotation_rects
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
def non_maximum_suppression(
|
def non_maximum_suppression(
|
||||||
@@ -12,4 +15,68 @@ def non_maximum_suppression(
|
|||||||
|
|
||||||
@return: A list of tuples of the remaining boxes after NMS together with their scores
|
@return: A list of tuples of the remaining boxes after NMS together with their scores
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
if not boxes_scores:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Sort the boxes by score in descending order
|
||||||
|
boxes_scores_sorted = sorted(boxes_scores, key=lambda bs: bs[1], reverse=True)
|
||||||
|
result = []
|
||||||
|
while boxes_scores_sorted:
|
||||||
|
# Select the box with highest score and remove it from the list
|
||||||
|
curr_box, curr_score = boxes_scores_sorted.pop(0)
|
||||||
|
result.append((curr_box, curr_score))
|
||||||
|
# Remove boxes with IoU > threshold
|
||||||
|
new_boxes = []
|
||||||
|
for box, score in boxes_scores_sorted:
|
||||||
|
if iou(curr_box, box) <= threshold:
|
||||||
|
new_boxes.append((box, score))
|
||||||
|
boxes_scores_sorted = new_boxes
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def read_boxes_from_file(filepath: str) -> List[Tuple[str, AnnotationRect, float]]:
    """
    Reads a file containing bounding boxes and scores in the format:
    {image_number} {x1} {y1} {x2} {y2} {score}

    Lines that do not consist of exactly six whitespace-separated fields are
    skipped. The four coordinates are parsed as integers and the score as a
    float; a six-field line with unparsable numbers raises ValueError.

    Returns a list of tuples: (image_number, AnnotationRect, score), where
    image_number is kept as the string found in the file.
    """
    boxes: List[Tuple[str, AnnotationRect, float]] = []
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if len(parts) != 6:
                continue
            img_id, *coords, raw_score = parts
            x1, y1, x2, y2 = map(int, coords)
            boxes.append((img_id, AnnotationRect(x1, y1, x2, y2), float(raw_score)))
    return boxes
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Apply NMS to the stored model output and draw the surviving boxes.

    Boxes are read from ``mmp/a6/model_output.txt`` and grouped per image id.
    Each group is filtered with non-maximum suppression (threshold 0.3) and
    only boxes scoring above 0.5 are kept. They are drawn onto the matching
    test image, and images that do not exist on disk are skipped.
    """
    per_image = defaultdict(list)
    for image_id, rect, score in read_boxes_from_file("mmp/a6/model_output.txt"):
        per_image[image_id].append((rect, score))

    for image_id in per_image:
        survivors = non_maximum_suppression(per_image[image_id], 0.3)

        confident_rects = []
        for rect, score in survivors:
            if score > 0.5:
                confident_rects.append(rect)

        input_path = f".data/mmp-public-3.2/test/{image_id}.jpg"
        output_path = f"mmp/a6/nms_output_{image_id}.png"

        if os.path.exists(input_path):
            draw_annotation_rects(
                input_path,
                confident_rects,
                rect_color=(255, 0, 0),
                rect_width=2,
                output_path=output_path,
            )


if __name__ == "__main__":
    main()
|
||||||
|
|||||||
BIN
mmp/a6/nms_output_02247421.png
Normal file
BIN
mmp/a6/nms_output_02247421.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 359 KiB |
BIN
mmp/a6/nms_output_02249576.png
Normal file
BIN
mmp/a6/nms_output_02249576.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 231 KiB |
BIN
mmp/a6/nms_output_02249614.png
Normal file
BIN
mmp/a6/nms_output_02249614.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 232 KiB |
Reference in New Issue
Block a user