update template code
This commit is contained in:
34
tests/test_a5.py
Normal file
34
tests/test_a5.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from . import check_bad_imports
|
||||
from mmp.a5 import model, main
|
||||
|
||||
current_assignment = pytest.mark.skipif(
|
||||
not (datetime(2025, 11, 13) <= datetime.now() <= datetime(2025, 11, 26, 23, 59, 59)),
|
||||
reason="This is not the current assignment.",
|
||||
)
|
||||
|
||||
|
||||
@current_assignment
|
||||
def test_no_abs_import():
|
||||
paths = list(Path().glob("mmp/a5/*.py"))
|
||||
check_bad_imports(paths)
|
||||
|
||||
|
||||
@current_assignment
|
||||
def test_model():
|
||||
net = model.MmpNet(num_widths=4, num_aspect_ratios=2)
|
||||
assert isinstance(net, torch.nn.Module)
|
||||
output = net(torch.rand(2, 3, 224, 224))
|
||||
|
||||
|
||||
@current_assignment
|
||||
def test_main():
|
||||
main.get_random_sampling_mask
|
||||
main.main
|
||||
main.step
|
||||
Reference in New Issue
Block a user