# 27 lines · 808 B · Python
import torch
|
|
import numpy as np
|
|
|
|
from ..a3.annotation import AnnotationRect
|
|
|
|
|
|
def get_bbr_loss(
    anchor_boxes: torch.Tensor,
    adjustments: torch.Tensor,
    groundtruths: torch.Tensor,
):
    """Compute the bounding-box regression loss for a batch of predictions.

    This is still a placeholder: every call raises ``NotImplementedError``.

    @param anchor_boxes: Batch of box coordinates taken from the anchor grid.
        NOTE(review): the coordinate layout is not stated; presumably
        (x1, y1, x2, y2) with shape (#data, 4), matching ``groundtruths`` —
        confirm against the anchor grid code.
    @param adjustments: Batch of adjustments predicted by the model,
        shape (#data, 4).
    @param groundtruths: Batch of ground-truth boxes given as
        (x1, y1, x2, y2), shape (#data, 4).
    @raise NotImplementedError: always, until the loss is implemented.
    """
    # TODO: implement the regression loss; its return type is not yet specified.
    raise NotImplementedError
|
|
|
|
|
|
def apply_bbr(anchor_box: np.ndarray, adjustment: torch.Tensor) -> AnnotationRect:
    """Build an AnnotationRect by applying model adjustments to one anchor box.

    This is still a placeholder: every call raises ``NotImplementedError``.

    @param anchor_box: Coordinates of a single box from the anchor grid.
        NOTE(review): presumably the same (x1, y1, x2, y2) layout as the
        ground-truth boxes used by get_bbr_loss — confirm.
    @param adjustment: Adjustments generated by the model for this anchor box.
        NOTE(review): presumably a length-4 vector, i.e. one row of the
        (#data, 4) adjustments batch — confirm.
    @return: The adjusted box as an AnnotationRect (once implemented).
    @raise NotImplementedError: always, until the function is implemented.
    """
    # TODO: implement decoding of the adjustment into box coordinates.
    raise NotImplementedError
|