mmp_wise2526_franksim/mmp/a2/main.py

import logging
from typing import Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from torchvision.models import MobileNet_V2_Weights

logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] %(levelname)s: %(message)s',
    datefmt='%H:%M:%S'
)
logger = logging.getLogger(__name__)

# Class labels of the CIFAR-10 dataset, in target-index order:
CLASSES = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)

class MmpNet(nn.Module):
    """Exercise 2.1"""

    def __init__(self, num_classes: int):
        super().__init__()
        # MobileNetV2 backbone with pretrained ImageNet weights.
        self.mobilenet = models.mobilenet_v2(
            weights=MobileNet_V2_Weights.DEFAULT)
        # Fresh classification head; the backbone's own classifier stays unused.
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.mobilenet.last_channel, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.mobilenet.features(x)  # (N, 1280, H', W') feature maps
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))  # global average pooling
        x = torch.flatten(x, 1)  # (N, 1280)
        x = self.classifier(x)  # (N, num_classes) logits
        return x
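

# A minimal smoke-test sketch (illustrative only; assumes CIFAR-10-sized
# inputs, though the adaptive pooling also accepts other spatial sizes):
#
#     model = MmpNet(num_classes=len(CLASSES))
#     logits = model(torch.randn(4, 3, 32, 32))  # -> torch.Size([4, 10])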


def get_dataloader(
    is_train: bool, data_root: str, batch_size: int, num_workers: int
) -> DataLoader:
    """Exercise 2.2

    @param is_train: Whether this is the training or the validation split
    @param data_root: Where to download the dataset to
    @param batch_size: Batch size for the data loader
    @param num_workers: Number of workers for the data loader
    """
    # Per-channel mean/std statistics of the CIFAR-10 training set.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2023, 0.1994, 0.2010]
        ),
    ])
    dataset = datasets.CIFAR10(
        root=data_root,
        train=is_train,
        download=True,
        transform=transform
    )
    # Shuffle only the training split; evaluation stays in a fixed order.
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=num_workers,
        pin_memory=True
    )
    return dataloader
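

# Usage sketch (path and sizes are illustrative, mirroring main() below):
#
#     train_loader = get_dataloader(True, "../../.data", batch_size=32, num_workers=6)
#     images, targets = next(iter(train_loader))  # (32, 3, 32, 32), (32,)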


def get_criterion_optimizer(model: nn.Module) -> Tuple[nn.Module, optim.Optimizer]:
    """Exercise 2.3a

    @param model: The model that is being trained.
    @return: Returns a tuple of the criterion and the optimizer.
    """
    # Cross-entropy over the raw logits, optimized with plain SGD.
    criterion = nn.CrossEntropyLoss()
    learning_rate = 0.004
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    return criterion, optimizer
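

# Usage sketch: the returned pair plugs straight into train_epoch below.
#
#     criterion, optimizer = get_criterion_optimizer(model)
#     loss = criterion(model(inputs), labels)  # inputs/labels from one batch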


def log_epoch_progress(epoch: int, total_epochs: int, phase: str):
    if phase == "start":
        logger.info(f"Epoch {epoch + 1}/{total_epochs} started.")
    elif phase == "end":
        logger.info(f"Epoch {epoch + 1}/{total_epochs} completed.")


def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
) -> float:
    """Exercise 2.3b

    @param model: The model that should be trained
    @param loader: The DataLoader that contains the training data
    @param criterion: The criterion that is used to calculate the loss for backpropagation
    @param optimizer: Executes the update step
    @param device: The device where the epoch should run on
    """
    model.train()
    running_loss = 0.0
    seen_samples = 0
    log_interval = max(len(loader) // 5, 1)  # log roughly five times per epoch
    for batch_idx, (inputs, labels) in enumerate(loader, 1):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight the batch loss by its size; counting the samples actually seen
        # keeps the running average correct even for a smaller final batch.
        running_loss += loss.item() * inputs.size(0)
        seen_samples += inputs.size(0)
        if batch_idx % log_interval == 0 or batch_idx == len(loader):
            avg_batch_loss = running_loss / seen_samples
            logger.info(
                f" [Batch {batch_idx}/{len(loader)}] Train Loss: {avg_batch_loss:.4f}")
    epoch_loss = running_loss / len(loader.dataset)
    logger.info(f" ---> Train Loss (Epoch): {epoch_loss:.4f}")
    return epoch_loss
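

# A sketch of exercising train_epoch on synthetic data (hypothetical toy
# tensors, CPU only; TensorDataset would come from torch.utils.data):
#
#     toy_data = TensorDataset(torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,)))
#     toy_loader = DataLoader(toy_data, batch_size=4)
#     toy_model = MmpNet(num_classes=10)
#     crit, opt = get_criterion_optimizer(toy_model)
#     train_epoch(toy_model, toy_loader, crit, opt, torch.device("cpu"))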


def eval_epoch(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Exercise 2.3c

    @param model: The model that should be evaluated
    @param loader: The DataLoader that contains the evaluation data
    @param device: The device where the epoch should run on
    @return: Returns the accuracy over the full validation dataset as a float.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = outputs.max(1)  # index of the largest logit per sample
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    accuracy = correct / total if total > 0 else 0.0
    logger.info(f" ---> Eval Accuracy: {accuracy * 100:.2f}%")
    return float(accuracy)


def main():
    """Exercise 2.3d"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_epochs = 10
    model = MmpNet(num_classes=len(CLASSES)).to(device=device)
    dataloader_train = get_dataloader(True, "../../.data", 32, 6)
    dataloader_eval = get_dataloader(False, "../../.data", 32, 6)
    criterion, optimizer = get_criterion_optimizer(model=model)
    for epoch in range(train_epochs):
        log_epoch_progress(epoch, train_epochs, "start")
        train_epoch(
            model=model,
            loader=dataloader_train,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
        )
        eval_epoch(
            model=model,
            loader=dataloader_eval,
            device=device,
        )
        log_epoch_progress(epoch, train_epochs, "end")


if __name__ == "__main__":
    main()