Files
mmp_wise2526_franksim/tests/test_a1.py
2025-10-13 14:48:00 +02:00

63 lines
1.8 KiB
Python

from datetime import datetime
import pytest
from pathlib import Path
from tempfile import TemporaryDirectory
import numpy as np
from PIL import Image
import torch
import torchvision
from . import check_bad_imports
from mmp.a1 import main, tensors
# Marker for tests that belong to this assignment: once the naive local clock
# (datetime.now()) is past 2025-10-22 23:59:59 the decorated tests are skipped.
current_assignment = pytest.mark.skipif(
    datetime.now() > datetime(2025, 10, 22, 23, 59, 59),
    reason="This is not the current assignment.",
)
@current_assignment
def test_no_abs_import():
    """Run ``check_bad_imports`` over every ``*.py`` file directly in ``mmp/a1``.

    NOTE(review): the glob is resolved against the current working
    directory, so it only finds files when pytest is started from the
    directory that contains ``mmp/`` -- confirm this matches how the suite
    is invoked.
    """
    a1_sources = [*Path(".").glob("mmp/a1/*.py")]
    check_bad_imports(a1_sources)
@current_assignment
def test_main():
    """Smoke-test the public entry points of ``mmp.a1.main``.

    Two random 128x128 RGB images are saved as JPEGs in a temporary
    directory and passed to ``build_batch`` once without and once with a
    ``Resize(320)`` transform. Both results must be tensors holding two
    entries. ``get_model`` must return a ``torch.nn.Module`` and
    ``main.main`` must exist and be callable.
    """
    # for testing: generate random images
    img1 = Image.fromarray((np.random.rand(128, 128, 3) * 255).astype(np.uint8))
    img2 = Image.fromarray((np.random.rand(128, 128, 3) * 255).astype(np.uint8))
    tfm = torchvision.transforms.Resize(320)

    with TemporaryDirectory() as img_dir:
        img_dir = Path(img_dir)
        p1 = str(img_dir / "img1.jpg")
        p2 = str(img_dir / "img2.jpg")
        img1.save(p1)
        img2.save(p2)
        # Build both batches while the image files still exist on disk.
        batch = main.build_batch([p1, p2])
        batch_transformed = main.build_batch([p1, p2], transform=tfm)

    assert isinstance(batch, torch.Tensor)
    assert len(batch) == 2

    # The transformed batch was previously computed but never checked.
    assert isinstance(batch_transformed, torch.Tensor)
    assert len(batch_transformed) == 2

    model = main.get_model()
    assert isinstance(model, torch.nn.Module)

    # Explicit check instead of a bare ``main.main`` expression statement.
    assert callable(main.main)
@current_assignment
def test_tensors():
    """Check return type and size of the ``mmp.a1.tensors`` helpers.

    Each helper is fed freshly drawn random tensors; only the type and the
    shape/length of the result are verified, not the numeric values.
    """
    # avg_color: a (3, 128, 128) input must yield a tensor of length 3.
    picture = torch.rand(3, 128, 128)
    avg = tensors.avg_color(picture)
    assert isinstance(avg, torch.Tensor)
    assert len(avg) == 3

    # mask: two (3, 128, 128) tensors, one (128, 128) tensor and the scalar 0.3.
    first = torch.rand(3, 128, 128)
    second = torch.rand(3, 128, 128)
    plane = torch.rand(128, 128)
    masked = tensors.mask(first, second, plane, 0.3)
    assert isinstance(masked, torch.Tensor)
    assert masked.shape == (3, 128, 128)

    # add_matrix_vector: a (3, 4) matrix and a length-3 vector give (3, 4).
    matrix = torch.rand(3, 4)
    vector = torch.rand(3)
    result = tensors.add_matrix_vector(matrix, vector)
    assert isinstance(result, torch.Tensor)
    assert result.shape == (3, 4)