Files
mmp_wise2526_franksim/tests/test_a1.py

63 lines
1.8 KiB
Python
Raw Normal View History

2025-10-13 14:48:00 +02:00
from datetime import datetime
import pytest
from pathlib import Path
from tempfile import TemporaryDirectory
import numpy as np
from PIL import Image
import torch
import torchvision
from . import check_bad_imports
from mmp.a1 import main, tensors
# Marker for tests belonging to assignment 1: they run up to and including the
# deadline below and are skipped afterwards.
# NOTE(review): datetime.now() is naive local time of the machine running the
# tests; the deadline is presumably meant in German local time — confirm for CI.
current_assignment = pytest.mark.skipif(
    datetime.now() > datetime(2025, 10, 22, 23, 59, 59),
    reason="This is not the current assignment.",
)
@current_assignment
def test_no_abs_import():
    """Run ``check_bad_imports`` over every Python file of ``mmp/a1``.

    The glob is resolved relative to the current working directory, so pytest
    has to be started from the repository root.
    """
    paths = list(Path().glob("mmp/a1/*.py"))
    # An empty match (e.g. pytest launched from another directory) would let
    # check_bad_imports succeed vacuously, hiding any bad imports.
    assert paths, "No files matched mmp/a1/*.py; run pytest from the repository root."
    check_bad_imports(paths)
@current_assignment
def test_main():
    """Smoke-test ``mmp.a1.main``: batch building, model creation, entry point.

    Two random 128x128 RGB images are written as JPEGs to a temporary
    directory and passed to ``main.build_batch`` both without and with a
    ``Resize(320)`` transform.
    """
    # Random test images; only the pipeline's types/lengths are checked.
    img1 = Image.fromarray((np.random.rand(128, 128, 3) * 255).astype(np.uint8))
    img2 = Image.fromarray((np.random.rand(128, 128, 3) * 255).astype(np.uint8))
    tfm = torchvision.transforms.Resize(320)

    with TemporaryDirectory() as tmp_dir:
        img_dir = Path(tmp_dir)
        p1 = str(img_dir / "img1.jpg")
        p2 = str(img_dir / "img2.jpg")
        img1.save(p1)
        img2.save(p2)

        batch = main.build_batch([p1, p2])
        batch_transformed = main.build_batch([p1, p2], transform=tfm)

        assert isinstance(batch, torch.Tensor)
        assert len(batch) == 2
        # The transformed batch was previously built but never checked; it must
        # satisfy the same basic contract as the untransformed one.
        assert isinstance(batch_transformed, torch.Tensor)
        assert len(batch_transformed) == 2

    model = main.get_model()
    assert isinstance(model, torch.nn.Module)
    # The original bare expression `main.main` only (implicitly) checked that
    # the attribute exists; state that intent explicitly. It is not called
    # here — TODO confirm whether main.main() should be exercised as well.
    assert callable(main.main)
@current_assignment
def test_tensors():
    """Check return types and shapes of the helpers in ``mmp.a1.tensors``.

    Inputs are random tensors, so only structural properties are asserted.
    """
    # avg_color on a single 3x128x128 tensor yields three values.
    image = torch.rand(3, 128, 128)
    avg = tensors.avg_color(image)
    assert isinstance(avg, torch.Tensor)
    assert len(avg) == 3

    # mask combines two 3x128x128 tensors via a 128x128 map and threshold 0.3.
    first = torch.rand(3, 128, 128)
    second = torch.rand(3, 128, 128)
    mask_map = torch.rand(128, 128)
    masked = tensors.mask(first, second, mask_map, 0.3)
    assert isinstance(masked, torch.Tensor)
    assert masked.shape == (3, 128, 128)

    # add_matrix_vector keeps the 3x4 shape of the matrix operand.
    matrix = torch.rand(3, 4)
    vector = torch.rand(3)
    result = tensors.add_matrix_vector(matrix, vector)
    assert isinstance(result, torch.Tensor)
    assert result.shape == (3, 4)