update template code
This commit is contained in:
42
tests/test_a2.py
Normal file
42
tests/test_a2.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from datetime import datetime
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import os
|
||||
|
||||
from . import check_bad_imports
|
||||
from mmp.a2 import main
|
||||
|
||||
current_assignment = pytest.mark.skipif(
|
||||
not (datetime(2025, 10, 23) <= datetime.now() <= datetime(2025, 10, 29, 23, 59, 59)),
|
||||
reason="This is not the current assignment.",
|
||||
)
|
||||
|
||||
|
||||
@current_assignment
|
||||
def test_no_abs_import():
|
||||
paths = list(Path().glob("mmp/a2/*.py"))
|
||||
check_bad_imports(paths)
|
||||
|
||||
|
||||
@current_assignment
|
||||
def test_main():
|
||||
assert issubclass(main.MmpNet, torch.nn.Module)
|
||||
net = main.MmpNet(12)
|
||||
assert isinstance(net(torch.rand(2, 3, 128, 128)), torch.Tensor)
|
||||
|
||||
loader = main.get_dataloader(
|
||||
is_train=False,
|
||||
data_root=os.path.join(os.environ["TORCH_HOME"], "datasets"),
|
||||
batch_size=2,
|
||||
num_workers=0,
|
||||
)
|
||||
x = next(iter(loader))
|
||||
crit, opt = main.get_criterion_optimizer(net)
|
||||
assert isinstance(
|
||||
crit(torch.rand(10, 4), (torch.rand(10) * 2).long()), torch.Tensor
|
||||
)
|
||||
assert isinstance(opt, torch.optim.Optimizer)
|
||||
main.train_epoch
|
||||
main.eval_epoch
|
||||
main.main
|
||||
Reference in New Issue
Block a user