update template code

This commit is contained in:
phatakmr
2025-10-13 14:48:00 +02:00
parent c9d159fcc6
commit 8f637a4a0d
46 changed files with 2955 additions and 1 deletions

52
tests/__init__.py Normal file
View File

@@ -0,0 +1,52 @@
import ast
from pathlib import Path
MMP_PARENTS = ["mmp"] + [f"a{i}" for i in range(1, 9)]
MMP_LOCAL_MODULES = [
"anchor_grid",
"annotation",
"bbr",
"dataset",
"evallib",
"label_grid",
"main",
"model",
"nms",
]
def check_bad_imports(file_path: Path | list[Path]):
if not isinstance(file_path, list):
file_path = [file_path]
bad_imports = []
for path in file_path:
with open(path, "r") as file:
tree = ast.parse(file.read(), filename=path)
absolute_imports = set()
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for n in node.names:
for req in MMP_PARENTS:
if req in n.name:
absolute_imports.add(n.lineno)
for req in MMP_LOCAL_MODULES:
if req == n.name:
absolute_imports.add(n.lineno)
if isinstance(node, ast.ImportFrom):
if node.level == 0:
for req in MMP_PARENTS:
if req in node.module:
absolute_imports.add(node.lineno)
for req in MMP_LOCAL_MODULES:
if req == node.module:
absolute_imports.add(node.lineno)
if len(absolute_imports) > 0:
bad_imports.append((path, sorted(absolute_imports)))
if len(bad_imports) != 0:
message = "\n There are absolute imports in the following files:\n"
for path, line_numbers in bad_imports:
message += f"{path}: {', '.join(str(num) for num in line_numbers)}\n"
assert len(bad_imports) == 0, message

62
tests/test_a1.py Normal file
View File

@@ -0,0 +1,62 @@
from datetime import datetime
import pytest
from pathlib import Path
from tempfile import TemporaryDirectory
import numpy as np
from PIL import Image
import torch
import torchvision
from . import check_bad_imports
from mmp.a1 import main, tensors
current_assignment = pytest.mark.skipif(
not (datetime.now() <= datetime(2025, 10, 22, 23, 59, 59)),
reason="This is not the current assignment.",
)
@current_assignment
def test_no_abs_import():
paths = list(Path().glob("mmp/a1/*.py"))
check_bad_imports(paths)
@current_assignment
def test_main():
# for testing: generate random image
img1 = Image.fromarray((np.random.rand(128, 128, 3) * 255).astype(np.uint8))
img2 = Image.fromarray((np.random.rand(128, 128, 3) * 255).astype(np.uint8))
tfm = torchvision.transforms.Resize(320)
with TemporaryDirectory() as img_dir:
img_dir = Path(img_dir)
p1 = str(img_dir / "img1.jpg")
p2 = str(img_dir / "img2.jpg")
img1.save(p1)
img2.save(p2)
batch = main.build_batch([p1, p2])
batch_transformed = main.build_batch([p1, p2], transform=tfm)
assert isinstance(batch, torch.Tensor)
assert len(batch) == 2
model = main.get_model()
assert isinstance(model, torch.nn.Module)
main.main
@current_assignment
def test_tensors():
avg = tensors.avg_color(torch.rand(3, 128, 128))
assert isinstance(avg, torch.Tensor)
assert len(avg) == 3
masked = tensors.mask(
torch.rand(3, 128, 128), torch.rand(3, 128, 128), torch.rand(128, 128), 0.3
)
assert isinstance(masked, torch.Tensor)
assert masked.shape == (3, 128, 128)
result = tensors.add_matrix_vector(torch.rand(3, 4), torch.rand(3))
assert isinstance(result, torch.Tensor)
assert result.shape == (3, 4)

42
tests/test_a2.py Normal file
View File

@@ -0,0 +1,42 @@
from datetime import datetime
import pytest
from pathlib import Path
import torch
import os
from . import check_bad_imports
from mmp.a2 import main
current_assignment = pytest.mark.skipif(
not (datetime(2025, 10, 23) <= datetime.now() <= datetime(2025, 10, 29, 23, 59, 59)),
reason="This is not the current assignment.",
)
@current_assignment
def test_no_abs_import():
paths = list(Path().glob("mmp/a2/*.py"))
check_bad_imports(paths)
@current_assignment
def test_main():
assert issubclass(main.MmpNet, torch.nn.Module)
net = main.MmpNet(12)
assert isinstance(net(torch.rand(2, 3, 128, 128)), torch.Tensor)
loader = main.get_dataloader(
is_train=False,
data_root=os.path.join(os.environ["TORCH_HOME"], "datasets"),
batch_size=2,
num_workers=0,
)
x = next(iter(loader))
crit, opt = main.get_criterion_optimizer(net)
assert isinstance(
crit(torch.rand(10, 4), (torch.rand(10) * 2).long()), torch.Tensor
)
assert isinstance(opt, torch.optim.Optimizer)
main.train_epoch
main.eval_epoch
main.main

49
tests/test_a3.py Normal file
View File

@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from pathlib import Path
import torch
import pytest
from . import check_bad_imports
current_assignment = pytest.mark.skipif(
not (datetime(2025, 10, 30) <= datetime.now() <= datetime(2025, 11, 5, 23, 59, 59)),
reason="This is not the current assignment.",
)
@current_assignment
def test_no_abs_import():
paths = list(Path().glob("mmp/a3/*.py"))
check_bad_imports(paths)
@current_assignment
def test_dataset():
from mmp.a3 import dataset
assert issubclass(dataset.MMP_Dataset, torch.utils.data.Dataset)
@current_assignment
def test_annotation():
from mmp.a3 import annotation
rect = annotation.AnnotationRect(x1=3, y1=2, x2=40, y2=50)
assert rect.x1 == 3
assert rect.y1 == 2
assert rect.x2 == 40
assert rect.y2 == 50
rect = annotation.AnnotationRect(x1=3, y1=2, x2=40, y2=50)
assert rect.x1 == 3
assert rect.y1 == 2
assert rect.x2 == 40
assert rect.y2 == 50
@current_assignment
def test_main():
from mmp.a3 import main
main.main

53
tests/test_a4.py Normal file
View File

@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from pathlib import Path
import pytest
from . import check_bad_imports
from mmp.a4 import anchor_grid, label_grid, dataset
current_assignment = pytest.mark.skipif(
not (datetime(2025, 11, 6) <= datetime.now() <= datetime(2025, 11, 12, 23, 59, 59)),
reason="This is not the current assignment.",
)
@current_assignment
def test_no_abs_import():
paths = list(Path().glob("mmp/a4/*.py"))
check_bad_imports(paths)
@current_assignment
def test_anchor_grid():
grid = anchor_grid.get_anchor_grid(
num_rows=4,
num_cols=5,
scale_factor=12.0,
anchor_widths=[30.0, 80.0],
aspect_ratios=[0.5, 1.0],
)
assert grid.ndim == 5
assert grid.shape[-1] == 4
@current_assignment
def test_iou():
label_grid.iou
@current_assignment
def test_label_grid():
lg, *_ = label_grid.get_label_grid(
anchor_grid=anchor_grid.get_anchor_grid(2, 2, 10, [10.0], [1.0]),
gts=[],
min_iou=0.5,
)
assert lg.dtype == bool
@current_assignment
def test_dataset():
dataset.MMP_Dataset
dataset.calculate_max_coverage

34
tests/test_a5.py Normal file
View File

@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from pathlib import Path
import pytest
import torch
from . import check_bad_imports
from mmp.a5 import model, main
current_assignment = pytest.mark.skipif(
not (datetime(2025, 11, 13) <= datetime.now() <= datetime(2025, 11, 26, 23, 59, 59)),
reason="This is not the current assignment.",
)
@current_assignment
def test_no_abs_import():
paths = list(Path().glob("mmp/a5/*.py"))
check_bad_imports(paths)
@current_assignment
def test_model():
net = model.MmpNet(num_widths=4, num_aspect_ratios=2)
assert isinstance(net, torch.nn.Module)
output = net(torch.rand(2, 3, 224, 224))
@current_assignment
def test_main():
main.get_random_sampling_mask
main.main
main.step

32
tests/test_a6.py Normal file
View File

@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from pathlib import Path
import pytest
from . import check_bad_imports
from mmp.a6 import nms, main
current_assignment = pytest.mark.skipif(
not (datetime(2025, 11, 27) <= datetime.now() <= datetime(2025, 12, 10, 23, 59, 59)),
reason="This is not the current assignment.",
)
@current_assignment
def test_no_abs_import():
paths = list(Path().glob("mmp/a6/*.py"))
check_bad_imports(paths)
@current_assignment
def test_nms():
nms.non_maximum_suppression
@current_assignment
def test_main():
main.evaluate_test
main.evaluate
main.batch_inference

31
tests/test_a7.py Normal file
View File

@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from pathlib import Path
import pytest
from . import check_bad_imports
from mmp.a7 import main, dataset
current_assignment = pytest.mark.skipif(
not (datetime(2025, 12, 11) <= datetime.now() <= datetime(2025, 12, 17, 23, 59, 59)),
reason="This is not the current assignment.",
)
@current_assignment
def test_no_abs_import():
paths = list(Path().glob("mmp/a7/*.py"))
check_bad_imports(paths)
@current_assignment
def test_main():
main.main
@current_assignment
def test_dataset():
dataset.MMP_Dataset
dataset.get_dataloader

36
tests/test_a8.py Normal file
View File

@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from pathlib import Path
import pytest
import torch
from . import check_bad_imports
from mmp.a8 import bbr
current_assignment = pytest.mark.skipif(
not (datetime(2025, 12, 18) <= datetime.now()),
reason="This is not the current assignment.",
)
@current_assignment
def test_no_abs_import():
paths = list(Path().glob("mmp/a8/*.py"))
check_bad_imports(paths)
@current_assignment
def test_bbr():
anchor_boxes = torch.tensor([10, 100, 10, 100]).unsqueeze(0)
adjustments = torch.tensor([0.2, 0.2, 0.2, 0.2]).unsqueeze(0)
groundtruths = torch.tensor([20, 100, 20, 90]).unsqueeze(0)
loss = bbr.get_bbr_loss(anchor_boxes, adjustments, groundtruths)
assert isinstance(loss, torch.Tensor)
assert loss.ndim == 0
anchor_pos = torch.tensor([10, 100, 10, 100])
adj = torch.tensor([0.2, 0.2, 0.2, 0.2])
rect = bbr.apply_bbr(anchor_pos, adj)
assert isinstance(rect.x1, float)