"""Smoke tests for the ``mmp.a2`` assignment package.

The tests are skipped unless the current local time lies inside the
assignment's active window (see ``current_assignment``).
"""

from datetime import datetime
import os
from pathlib import Path

import pytest
import torch

from . import check_bad_imports
from mmp.a2 import main

# Active window: 2025-10-23 00:00 up to, but excluding, 2025-10-30 00:00
# (naive local time). The exclusive midnight bound keeps the final fractional
# second of 2025-10-29 inside the window, which an inclusive 23:59:59 bound
# would not. The condition is evaluated once, when this module is imported.
current_assignment = pytest.mark.skipif(
    not (datetime(2025, 10, 23) <= datetime.now() < datetime(2025, 10, 30)),
    reason="This is not the current assignment.",
)


@current_assignment
def test_no_abs_import():
    """Run ``check_bad_imports`` over every ``mmp/a2/*.py`` file.

    Judging by the test name, ``check_bad_imports`` presumably rejects
    absolute imports -- confirm against its definition.
    """
    # The glob is relative to the current working directory, so pytest must be
    # run from the repository root. Fail loudly instead of checking nothing.
    paths = list(Path().glob("mmp/a2/*.py"))
    assert paths, "no files matched mmp/a2/*.py; run pytest from the repo root"
    check_bad_imports(paths)


@current_assignment
def test_main():
    """Check that ``mmp.a2.main`` exposes the expected API and that it runs."""
    # The model must be a torch module that maps a (2, 3, 128, 128) batch to a
    # tensor. The constructor argument 12 is presumably the number of output
    # classes -- TODO confirm.
    assert issubclass(main.MmpNet, torch.nn.Module)
    net = main.MmpNet(12)
    assert isinstance(net(torch.rand(2, 3, 128, 128)), torch.Tensor)

    # The evaluation loader must build from $TORCH_HOME/datasets (a KeyError
    # here means TORCH_HOME is unset) and yield at least one batch; next()
    # raises StopIteration otherwise.
    loader = main.get_dataloader(
        is_train=False,
        data_root=os.path.join(os.environ["TORCH_HOME"], "datasets"),
        batch_size=2,
        num_workers=0,
    )
    next(iter(loader))

    # The criterion must accept a (10, 4) float tensor with int64 targets in
    # {0, 1} and return a tensor; the optimizer must be a torch Optimizer.
    crit, opt = main.get_criterion_optimizer(net)
    assert isinstance(
        crit(torch.rand(10, 4), (torch.rand(10) * 2).long()), torch.Tensor
    )
    assert isinstance(opt, torch.optim.Optimizer)

    # The remaining entry points are not exercised here; only require that
    # they exist (getattr raises AttributeError otherwise) and are callable.
    for name in ("train_epoch", "eval_epoch", "main"):
        assert callable(getattr(main, name)), f"main.{name} is not callable"