update template code
This commit is contained in:
52
tests/__init__.py
Normal file
52
tests/__init__.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import ast
|
||||
from pathlib import Path
|
||||
|
||||
MMP_PARENTS = ["mmp"] + [f"a{i}" for i in range(1, 9)]
|
||||
MMP_LOCAL_MODULES = [
|
||||
"anchor_grid",
|
||||
"annotation",
|
||||
"bbr",
|
||||
"dataset",
|
||||
"evallib",
|
||||
"label_grid",
|
||||
"main",
|
||||
"model",
|
||||
"nms",
|
||||
]
|
||||
|
||||
|
||||
def check_bad_imports(file_path: Path | list[Path]):
|
||||
if not isinstance(file_path, list):
|
||||
file_path = [file_path]
|
||||
bad_imports = []
|
||||
|
||||
for path in file_path:
|
||||
with open(path, "r") as file:
|
||||
tree = ast.parse(file.read(), filename=path)
|
||||
|
||||
absolute_imports = set()
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Import):
|
||||
for n in node.names:
|
||||
for req in MMP_PARENTS:
|
||||
if req in n.name:
|
||||
absolute_imports.add(n.lineno)
|
||||
for req in MMP_LOCAL_MODULES:
|
||||
if req == n.name:
|
||||
absolute_imports.add(n.lineno)
|
||||
if isinstance(node, ast.ImportFrom):
|
||||
if node.level == 0:
|
||||
for req in MMP_PARENTS:
|
||||
if req in node.module:
|
||||
absolute_imports.add(node.lineno)
|
||||
for req in MMP_LOCAL_MODULES:
|
||||
if req == node.module:
|
||||
absolute_imports.add(node.lineno)
|
||||
if len(absolute_imports) > 0:
|
||||
bad_imports.append((path, sorted(absolute_imports)))
|
||||
if len(bad_imports) != 0:
|
||||
message = "\n There are absolute imports in the following files:\n"
|
||||
for path, line_numbers in bad_imports:
|
||||
message += f"{path}: {', '.join(str(num) for num in line_numbers)}\n"
|
||||
assert len(bad_imports) == 0, message
|
||||
|
||||
62
tests/test_a1.py
Normal file
62
tests/test_a1.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from datetime import datetime
import pytest
from pathlib import Path
from tempfile import TemporaryDirectory
import numpy as np
from PIL import Image
import torch
import torchvision


from . import check_bad_imports
from mmp.a1 import main, tensors


# Skip marker: these tests only run while assignment 1 is the active one.
current_assignment = pytest.mark.skipif(
    not (datetime.now() <= datetime(2025, 10, 22, 23, 59, 59)),
    reason="This is not the current assignment.",
)


@current_assignment
def test_no_abs_import():
    """Assignment code must not use absolute imports of mmp modules."""
    paths = list(Path().glob("mmp/a1/*.py"))
    check_bad_imports(paths)


@current_assignment
def test_main():
    """Smoke-test batch building and model construction."""
    # for testing: generate random image
    img1 = Image.fromarray((np.random.rand(128, 128, 3) * 255).astype(np.uint8))
    img2 = Image.fromarray((np.random.rand(128, 128, 3) * 255).astype(np.uint8))
    tfm = torchvision.transforms.Resize(320)
    with TemporaryDirectory() as img_dir:
        img_dir = Path(img_dir)
        p1 = str(img_dir / "img1.jpg")
        p2 = str(img_dir / "img2.jpg")
        img1.save(p1)
        img2.save(p2)
        batch = main.build_batch([p1, p2])
        batch_transformed = main.build_batch([p1, p2], transform=tfm)
        assert isinstance(batch, torch.Tensor)
        assert len(batch) == 2
        # Fix: the transformed batch was previously built but never checked.
        assert isinstance(batch_transformed, torch.Tensor)
        assert len(batch_transformed) == 2

        model = main.get_model()
        assert isinstance(model, torch.nn.Module)

        main.main  # existence check only


@current_assignment
def test_tensors():
    """Check the required tensor utilities."""
    avg = tensors.avg_color(torch.rand(3, 128, 128))
    assert isinstance(avg, torch.Tensor)
    assert len(avg) == 3

    masked = tensors.mask(
        torch.rand(3, 128, 128), torch.rand(3, 128, 128), torch.rand(128, 128), 0.3
    )
    assert isinstance(masked, torch.Tensor)
    assert masked.shape == (3, 128, 128)

    result = tensors.add_matrix_vector(torch.rand(3, 4), torch.rand(3))
    assert isinstance(result, torch.Tensor)
    assert result.shape == (3, 4)
|
||||
42
tests/test_a2.py
Normal file
42
tests/test_a2.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from datetime import datetime
import pytest
from pathlib import Path
import torch
import os

from . import check_bad_imports
from mmp.a2 import main

# Skip marker: active only during the assignment-2 hand-in window.
current_assignment = pytest.mark.skipif(
    not (datetime(2025, 10, 23) <= datetime.now() <= datetime(2025, 10, 29, 23, 59, 59)),
    reason="This is not the current assignment.",
)


@current_assignment
def test_no_abs_import():
    """Assignment code must not use absolute imports of mmp modules."""
    paths = list(Path().glob("mmp/a2/*.py"))
    check_bad_imports(paths)


@current_assignment
def test_main():
    """Smoke-test network, dataloader, and training utilities."""
    assert issubclass(main.MmpNet, torch.nn.Module)
    net = main.MmpNet(12)
    assert isinstance(net(torch.rand(2, 3, 128, 128)), torch.Tensor)

    loader = main.get_dataloader(
        is_train=False,
        data_root=os.path.join(os.environ["TORCH_HOME"], "datasets"),
        batch_size=2,
        num_workers=0,
    )
    # Fix: the drawn batch was bound to an unused local; drawing one batch
    # is itself the check that the loader is iterable.
    next(iter(loader))
    crit, opt = main.get_criterion_optimizer(net)
    assert isinstance(
        crit(torch.rand(10, 4), (torch.rand(10) * 2).long()), torch.Tensor
    )
    assert isinstance(opt, torch.optim.Optimizer)
    # Existence checks for the required entry points.
    main.train_epoch
    main.eval_epoch
    main.main
|
||||
49
tests/test_a3.py
Normal file
49
tests/test_a3.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# -*- coding: utf-8 -*-
from datetime import datetime
from pathlib import Path
import torch
import pytest


from . import check_bad_imports


# Skip marker: active only during the assignment-3 hand-in window.
current_assignment = pytest.mark.skipif(
    not (datetime(2025, 10, 30) <= datetime.now() <= datetime(2025, 11, 5, 23, 59, 59)),
    reason="This is not the current assignment.",
)


@current_assignment
def test_no_abs_import():
    """Assignment code must not use absolute imports of mmp modules."""
    paths = list(Path().glob("mmp/a3/*.py"))
    check_bad_imports(paths)


@current_assignment
def test_dataset():
    """MMP_Dataset must subclass torch's Dataset."""
    from mmp.a3 import dataset

    assert issubclass(dataset.MMP_Dataset, torch.utils.data.Dataset)


@current_assignment
def test_annotation():
    """AnnotationRect must expose the four corner coordinates.

    Fix: the construct-and-check sequence was copy-pasted twice verbatim;
    a single pass covers the same behavior.
    """
    from mmp.a3 import annotation

    rect = annotation.AnnotationRect(x1=3, y1=2, x2=40, y2=50)
    assert rect.x1 == 3
    assert rect.y1 == 2
    assert rect.x2 == 40
    assert rect.y2 == 50


@current_assignment
def test_main():
    """The a3 entry point must exist."""
    from mmp.a3 import main

    main.main
|
||||
53
tests/test_a4.py
Normal file
53
tests/test_a4.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# -*- coding: utf-8 -*-

from datetime import datetime
from pathlib import Path
import pytest

from . import check_bad_imports
from mmp.a4 import anchor_grid, label_grid, dataset


# Assignment 4 window: 2025-11-06 through 2025-11-12 (inclusive).
_now = datetime.now()
current_assignment = pytest.mark.skipif(
    _now < datetime(2025, 11, 6) or _now > datetime(2025, 11, 12, 23, 59, 59),
    reason="This is not the current assignment.",
)


@current_assignment
def test_no_abs_import():
    """All a4 modules must stick to relative imports."""
    check_bad_imports(list(Path().glob("mmp/a4/*.py")))


@current_assignment
def test_anchor_grid():
    """An anchor grid is 5-D with the box coordinates on the last axis."""
    widths = [30.0, 80.0]
    ratios = [0.5, 1.0]
    grid = anchor_grid.get_anchor_grid(
        num_rows=4,
        num_cols=5,
        scale_factor=12.0,
        anchor_widths=widths,
        aspect_ratios=ratios,
    )
    assert grid.ndim == 5
    assert grid.shape[-1] == 4


@current_assignment
def test_iou():
    """The iou helper must exist."""
    label_grid.iou


@current_assignment
def test_label_grid():
    """get_label_grid yields a boolean grid even with no ground truths."""
    small_grid = anchor_grid.get_anchor_grid(2, 2, 10, [10.0], [1.0])
    lg, *_ = label_grid.get_label_grid(anchor_grid=small_grid, gts=[], min_iou=0.5)
    assert lg.dtype == bool


@current_assignment
def test_dataset():
    """Required dataset symbols must exist."""
    dataset.MMP_Dataset
    dataset.calculate_max_coverage
|
||||
34
tests/test_a5.py
Normal file
34
tests/test_a5.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# -*- coding: utf-8 -*-

from datetime import datetime
from pathlib import Path
import pytest
import torch

from . import check_bad_imports
from mmp.a5 import model, main

# Skip marker: active only during the assignment-5 hand-in window.
current_assignment = pytest.mark.skipif(
    not (datetime(2025, 11, 13) <= datetime.now() <= datetime(2025, 11, 26, 23, 59, 59)),
    reason="This is not the current assignment.",
)


@current_assignment
def test_no_abs_import():
    """Assignment code must not use absolute imports of mmp modules."""
    paths = list(Path().glob("mmp/a5/*.py"))
    check_bad_imports(paths)


@current_assignment
def test_model():
    """MmpNet must build and accept an image batch."""
    net = model.MmpNet(num_widths=4, num_aspect_ratios=2)
    assert isinstance(net, torch.nn.Module)
    output = net(torch.rand(2, 3, 224, 224))
    # Fix: the forward output was previously computed but never validated.
    # NOTE(review): the output type is not fixed by this file — tighten to a
    # tensor/shape check once the forward contract is confirmed.
    assert output is not None


@current_assignment
def test_main():
    """Required training helpers must exist."""
    main.get_random_sampling_mask
    main.main
    main.step
|
||||
32
tests/test_a6.py
Normal file
32
tests/test_a6.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# -*- coding: utf-8 -*-

from datetime import datetime
from pathlib import Path
import pytest

from . import check_bad_imports
from mmp.a6 import nms, main


# Assignment 6 window: 2025-11-27 through 2025-12-10 (inclusive).
_now = datetime.now()
current_assignment = pytest.mark.skipif(
    _now < datetime(2025, 11, 27) or _now > datetime(2025, 12, 10, 23, 59, 59),
    reason="This is not the current assignment.",
)


@current_assignment
def test_no_abs_import():
    """All a6 modules must stick to relative imports."""
    check_bad_imports(list(Path().glob("mmp/a6/*.py")))


@current_assignment
def test_nms():
    """The non-maximum-suppression entry point must exist."""
    nms.non_maximum_suppression


@current_assignment
def test_main():
    """Evaluation helpers must exist."""
    main.evaluate_test
    main.evaluate
    main.batch_inference
|
||||
31
tests/test_a7.py
Normal file
31
tests/test_a7.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# -*- coding: utf-8 -*-

from datetime import datetime
from pathlib import Path
import pytest

from . import check_bad_imports
from mmp.a7 import main, dataset


# Assignment 7 window: 2025-12-11 through 2025-12-17 (inclusive).
_now = datetime.now()
current_assignment = pytest.mark.skipif(
    _now < datetime(2025, 12, 11) or _now > datetime(2025, 12, 17, 23, 59, 59),
    reason="This is not the current assignment.",
)


@current_assignment
def test_no_abs_import():
    """All a7 modules must stick to relative imports."""
    check_bad_imports(list(Path().glob("mmp/a7/*.py")))


@current_assignment
def test_main():
    """The a7 entry point must exist."""
    main.main


@current_assignment
def test_dataset():
    """Required dataset symbols must exist."""
    dataset.MMP_Dataset
    dataset.get_dataloader
|
||||
36
tests/test_a8.py
Normal file
36
tests/test_a8.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# -*- coding: utf-8 -*-

from datetime import datetime
from pathlib import Path
import pytest
import torch

from . import check_bad_imports
from mmp.a8 import bbr


# Assignment 8 is the last one: active from 2025-12-18 onwards.
_now = datetime.now()
current_assignment = pytest.mark.skipif(
    _now < datetime(2025, 12, 18),
    reason="This is not the current assignment.",
)


@current_assignment
def test_no_abs_import():
    """All a8 modules must stick to relative imports."""
    check_bad_imports(list(Path().glob("mmp/a8/*.py")))


@current_assignment
def test_bbr():
    """Bounding-box regression: scalar loss and box adjustment."""
    boxes = torch.tensor([10, 100, 10, 100]).unsqueeze(0)
    deltas = torch.tensor([0.2, 0.2, 0.2, 0.2]).unsqueeze(0)
    targets = torch.tensor([20, 100, 20, 90]).unsqueeze(0)
    loss = bbr.get_bbr_loss(boxes, deltas, targets)
    assert isinstance(loss, torch.Tensor)
    assert loss.ndim == 0

    single_box = torch.tensor([10, 100, 10, 100])
    single_delta = torch.tensor([0.2, 0.2, 0.2, 0.2])
    rect = bbr.apply_bbr(single_box, single_delta)
    assert isinstance(rect.x1, float)
|
||||
Reference in New Issue
Block a user