Files
mmp_wise2526_franksim/tests/test_a2.py

43 lines
1.1 KiB
Python
Raw Permalink Normal View History

2025-10-13 14:48:00 +02:00
from datetime import datetime
import pytest
from pathlib import Path
import torch
import os
from . import check_bad_imports
from mmp.a2 import main
# Active window for assignment 2, as naive local-time datetimes (compared
# against the naive ``datetime.now()``). The window is half-open
# [start, end): it starts at 2025-10-23 00:00 and includes all of 2025-10-29.
# Using the next midnight as an exclusive bound avoids dropping the final
# fractional second (``now()`` carries microseconds, so 23:59:59.5 would be
# greater than an inclusive 23:59:59 bound).
_ASSIGNMENT_START = datetime(2025, 10, 23)
_ASSIGNMENT_END = datetime(2025, 10, 30)  # exclusive

# Marker that skips a test unless "now" falls inside the assignment window.
# The condition is evaluated once, when this module is imported.
current_assignment = pytest.mark.skipif(
    not (_ASSIGNMENT_START <= datetime.now() < _ASSIGNMENT_END),
    reason="This is not the current assignment.",
)
@current_assignment
def test_no_abs_import():
    """Run ``check_bad_imports`` over every ``.py`` file directly in ``mmp/a2``.

    Per the test name, ``check_bad_imports`` presumably rejects absolute
    imports (TODO confirm against its definition).

    The glob is resolved relative to the current working directory, so pytest
    must be started from the repository root. An empty match is reported as a
    failure so the check cannot pass vacuously when run from elsewhere.
    """
    paths = list(Path().glob("mmp/a2/*.py"))
    assert paths, (
        "no files matched 'mmp/a2/*.py' relative to the working directory; "
        "run pytest from the repository root"
    )
    check_bad_imports(paths)
@current_assignment
def test_main():
    """Smoke-test the public API that ``mmp.a2.main`` must expose.

    Checks that:
    * ``MmpNet`` subclasses ``torch.nn.Module`` and its forward pass returns a
      ``torch.Tensor`` for a random ``(2, 3, 128, 128)`` input;
    * ``get_dataloader`` builds a loader that yields at least one batch from
      ``$TORCH_HOME/datasets``;
    * ``get_criterion_optimizer`` returns a criterion that produces a tensor
      and a ``torch.optim.Optimizer``;
    * ``train_epoch``, ``eval_epoch`` and ``main`` exist and are callable.

    Requires the ``TORCH_HOME`` environment variable; ``os.environ[...]``
    raises ``KeyError`` if it is unset.
    """
    assert issubclass(main.MmpNet, torch.nn.Module)
    # 12 is MmpNet's sole constructor argument here (presumably the number of
    # output classes — TODO confirm against mmp.a2.main).
    net = main.MmpNet(12)
    assert isinstance(net(torch.rand(2, 3, 128, 128)), torch.Tensor)

    loader = main.get_dataloader(
        is_train=False,
        data_root=os.path.join(os.environ["TORCH_HOME"], "datasets"),
        batch_size=2,
        num_workers=0,
    )
    # Fetch one batch so dataset loading is actually exercised; its contents
    # are not inspected.
    next(iter(loader))

    crit, opt = main.get_criterion_optimizer(net)
    # Random scores of shape (10, 4) and integer targets drawn from {0, 1}:
    # rand() is in [0, 1), scaled by 2 and truncated by .long().
    assert isinstance(
        crit(torch.rand(10, 4), (torch.rand(10) * 2).long()), torch.Tensor
    )
    assert isinstance(opt, torch.optim.Optimizer)

    # Required entry points. Attribute access raises AttributeError if one is
    # missing; callable() additionally rejects non-function attributes.
    assert callable(main.train_epoch)
    assert callable(main.eval_epoch)
    assert callable(main.main)