37 lines
978 B
Python
37 lines
978 B
Python
|
|
# -*- coding: utf-8 -*-
|
||
|
|
|
||
|
|
from datetime import datetime
|
||
|
|
from pathlib import Path
|
||
|
|
import pytest
|
||
|
|
import torch
|
||
|
|
|
||
|
|
from . import check_bad_imports
|
||
|
|
from mmp.a8 import bbr
|
||
|
|
|
||
|
|
|
||
|
|
# Marker applied to every test of this assignment: the tests are skipped
# while the (naive, local-time) clock is still before 2025-12-18.
# The condition is evaluated once, when this module is imported.
current_assignment = pytest.mark.skipif(
    datetime.now() < datetime(2025, 12, 18),
    reason="This is not the current assignment.",
)
|
||
|
|
|
||
|
|
|
||
|
|
@current_assignment
def test_no_abs_import():
    """Run ``check_bad_imports`` over every module in ``mmp/a8``.

    Per the test name, this presumably rejects absolute imports inside the
    package; the actual rule lives in ``check_bad_imports``.

    The glob is resolved relative to the current working directory, so the
    test has to be run from the directory that contains ``mmp/``. An empty
    match is reported as a failure instead of silently passing.
    """
    # Sorted so that any failure report lists files in a stable order.
    paths = sorted(Path().glob("mmp/a8/*.py"))
    # Without this guard, running pytest from another directory hands an
    # empty list to check_bad_imports and the test passes while checking
    # nothing at all.
    assert paths, (
        "No files matched mmp/a8/*.py; run pytest from the directory "
        "containing mmp/."
    )
    check_bad_imports(paths)
|
||
|
|
|
||
|
|
|
||
|
|
@current_assignment
def test_bbr():
    """Smoke-test the return types of ``bbr.get_bbr_loss`` and ``bbr.apply_bbr``.

    Only the type and rank of the results are checked, not their values.
    """
    # Batched inputs: each nested list yields a (1, 4) tensor, identical in
    # shape and dtype to ``torch.tensor([...]).unsqueeze(0)``.
    anchor_boxes = torch.tensor([[10, 100, 10, 100]])
    adjustments = torch.tensor([[0.2, 0.2, 0.2, 0.2]])
    groundtruths = torch.tensor([[20, 100, 20, 90]])

    loss = bbr.get_bbr_loss(anchor_boxes, adjustments, groundtruths)
    # The loss has to be a scalar (0-dim) tensor.
    assert isinstance(loss, torch.Tensor)
    assert loss.dim() == 0

    # Unbatched (1-D) anchor and adjustment for apply_bbr.
    rect = bbr.apply_bbr(
        torch.tensor([10, 100, 10, 100]),
        torch.tensor([0.2, 0.2, 0.2, 0.2]),
    )
    assert isinstance(rect.x1, float)
|