63 lines
1.8 KiB
Python
63 lines
1.8 KiB
Python
|
|
from datetime import datetime
|
||
|
|
import pytest
|
||
|
|
from pathlib import Path
|
||
|
|
from tempfile import TemporaryDirectory
|
||
|
|
import numpy as np
|
||
|
|
from PIL import Image
|
||
|
|
import torch
|
||
|
|
import torchvision
|
||
|
|
|
||
|
|
from . import check_bad_imports
|
||
|
|
from mmp.a1 import main, tensors
|
||
|
|
|
||
|
|
# Deadline of the assignment covered by this module (a naive datetime,
# compared against datetime.now(), i.e. the machine's local time).
_ASSIGNMENT_DEADLINE = datetime(2025, 10, 22, 23, 59, 59)

# Marker for tests of the currently open assignment: once the deadline has
# passed, every test decorated with it is skipped. The condition is evaluated
# a single time, when this module is imported.
current_assignment = pytest.mark.skipif(
    datetime.now() > _ASSIGNMENT_DEADLINE,
    reason="This is not the current assignment.",
)
|
||
|
|
|
||
|
|
|
||
|
|
@current_assignment
def test_no_abs_import():
    """Run ``check_bad_imports`` over every Python file in ``mmp/a1``.

    Judging by the test name, ``check_bad_imports`` presumably rejects
    absolute imports (verify against its definition). Files are globbed
    relative to the current working directory, so pytest must be started
    from the directory that contains ``mmp/``.
    """
    paths = sorted(Path().glob("mmp/a1/*.py"))
    # Without this guard, running from any other directory globs nothing and
    # the import check passes vacuously.
    assert paths, (
        "no files matched mmp/a1/*.py; run pytest from the directory containing mmp/"
    )
    check_bad_imports(paths)
|
||
|
|
|
||
|
|
|
||
|
|
@current_assignment
def test_main():
    """Smoke-test the public helpers of ``mmp.a1.main``.

    Two random 128x128 RGB images are written as JPEGs to a temporary
    directory and loaded back through ``build_batch``, both without and with
    a ``Resize(320)`` transform. Then ``get_model`` must return a
    ``torch.nn.Module`` and ``main`` must be a callable attribute. ``main``
    itself is not invoked.
    """
    # Seeded so that a failure can be reproduced with the same inputs.
    rng = np.random.default_rng(0)
    images = [
        Image.fromarray((rng.random((128, 128, 3)) * 255).astype(np.uint8))
        for _ in range(2)
    ]
    tfm = torchvision.transforms.Resize(320)

    with TemporaryDirectory() as tmp:
        img_dir = Path(tmp)
        paths = []
        for index, image in enumerate(images, start=1):
            path = str(img_dir / f"img{index}.jpg")
            image.save(path)
            paths.append(path)
        # build_batch must run while the temporary files still exist.
        batch = main.build_batch(paths)
        batch_transformed = main.build_batch(paths, transform=tfm)

    assert isinstance(batch, torch.Tensor)
    assert len(batch) == len(images)
    # The transformed batch must satisfy the same basic contract.
    assert isinstance(batch_transformed, torch.Tensor)
    assert len(batch_transformed) == len(images)

    model = main.get_model()
    assert isinstance(model, torch.nn.Module)

    # Only check that the entry point exists and is callable; calling it here
    # would run the whole program.
    assert callable(main.main)
|
||
|
|
|
||
|
|
|
||
|
|
@current_assignment
def test_tensors():
    """Check return types and shapes of the helpers in ``mmp.a1.tensors``."""
    # avg_color on a random 3x128x128 tensor must give a tensor of length 3.
    image = torch.rand(3, 128, 128)
    average = tensors.avg_color(image)
    assert isinstance(average, torch.Tensor)
    assert len(average) == 3

    # mask on two 3x128x128 tensors, a 128x128 tensor and 0.3 must keep the
    # 3x128x128 shape. Inputs are drawn in the same order as the call's args.
    img_a = torch.rand(3, 128, 128)
    img_b = torch.rand(3, 128, 128)
    plane = torch.rand(128, 128)
    blended = tensors.mask(img_a, img_b, plane, 0.3)
    assert isinstance(blended, torch.Tensor)
    assert tuple(blended.shape) == (3, 128, 128)

    # add_matrix_vector on a 3x4 matrix and a length-3 vector must give 3x4.
    matrix = torch.rand(3, 4)
    vector = torch.rand(3)
    summed = tensors.add_matrix_vector(matrix, vector)
    assert isinstance(summed, torch.Tensor)
    assert tuple(summed.shape) == (3, 4)
|