import logging
from typing import Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models, datasets, transforms
from torchvision.models import MobileNet_V2_Weights

logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s: %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)

# Class labels of the CIFAR-10 dataset, in index order:
CLASSES = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)


class MmpNet(nn.Module):
    """Exercise 2.1"""

    def __init__(self, num_classes: int):
        super().__init__()
        # ImageNet-pretrained MobileNetV2 backbone; its stock classifier is
        # bypassed in forward() in favour of the new head below, which is
        # sized for num_classes.
        self.mobilenet = models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.mobilenet.last_channel, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.mobilenet.features(x)
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


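# Quick shape sanity check (an illustrative sketch, not part of the exercise):
# MobileNetV2's feature extractor downsamples by a total factor of 32, so a
# 32x32 CIFAR-10 image yields a 1x1 feature map, which the adaptive pooling
# and flatten reduce to a [batch, last_channel] vector before the new head.
#
#   model = MmpNet(num_classes=10)
#   dummy = torch.randn(4, 3, 32, 32)
#   assert model(dummy).shape == (4, 10)

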
def get_dataloader(
    is_train: bool, data_root: str, batch_size: int, num_workers: int
) -> DataLoader:
    """Exercise 2.2

    @param is_train: Whether this is the training or validation split
    @param data_root: Where to download the dataset to
    @param batch_size: Batch size for the data loader
    @param num_workers: Number of workers for the data loader
    """
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            # Per-channel mean and std of the CIFAR-10 training set.
            transforms.Normalize(
                mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]
            ),
        ]
    )
    dataset = datasets.CIFAR10(
        root=data_root, train=is_train, download=True, transform=transform
    )
    # Shuffle only the training split; pin memory to speed up host-to-GPU copies.
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=num_workers,
        pin_memory=True,
    )
    return dataloader


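# Example usage (the path and sizes here are placeholders):
#   train_loader = get_dataloader(True, "./data", batch_size=32, num_workers=4)
#   images, labels = next(iter(train_loader))  # images: [32, 3, 32, 32]

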
def get_criterion_optimizer(model: nn.Module) -> Tuple[nn.Module, optim.Optimizer]:
    """Exercise 2.3a

    @param model: The model that is being trained.
    @return: Returns a tuple of the criterion and the optimizer.
    """
    criterion = nn.CrossEntropyLoss()
    learning_rate = 0.004
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    return criterion, optimizer


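# Sketch of how the returned pair drives a single update step (tensor shapes
# are assumptions for illustration; CrossEntropyLoss expects raw logits and
# integer class labels):
#   criterion, optimizer = get_criterion_optimizer(model)
#   logits = model(images)            # [B, 10]
#   loss = criterion(logits, labels)  # labels: [B], values in 0..9
#   optimizer.zero_grad(); loss.backward(); optimizer.step()

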
def log_epoch_progress(epoch: int, total_epochs: int, phase: str):
    if phase == "start":
        logger.info(f"Epoch {epoch + 1}/{total_epochs} started.")
    elif phase == "end":
        logger.info(f"Epoch {epoch + 1}/{total_epochs} completed.")


def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
) -> float:
    """Exercise 2.3b

    @param model: The model that should be trained
    @param loader: The DataLoader that contains the training data
    @param criterion: The criterion that is used to calculate the loss for backpropagation
    @param optimizer: Executes the update step
    @param device: The device that the epoch should run on
    """
    model.train()
    running_loss = 0.0
    samples_seen = 0
    log_interval = max(len(loader) // 5, 1)

    for batch_idx, (inputs, labels) in enumerate(loader, 1):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by the batch size so that a smaller
        # final batch contributes proportionally to the running sum.
        running_loss += loss.item() * inputs.size(0)
        samples_seen += inputs.size(0)

        if batch_idx % log_interval == 0 or batch_idx == len(loader):
            avg_batch_loss = running_loss / samples_seen
            logger.info(
                f" [Batch {batch_idx}/{len(loader)}] Train Loss: {avg_batch_loss:.4f}"
            )

    epoch_loss = running_loss / len(loader.dataset)
    logger.info(f" ---> Train Loss (Epoch): {epoch_loss:.4f}")
    return epoch_loss


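# Note on the loss bookkeeping above: CrossEntropyLoss returns the mean over
# each batch, so summing loss * batch_size and dividing by the number of
# samples seen (or by len(loader.dataset) at the end) gives the exact
# per-sample average even when the final batch is smaller than batch_size.

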
def eval_epoch(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Exercise 2.3c

    @param model: The model that should be evaluated
    @param loader: The DataLoader that contains the evaluation data
    @param device: The device that the epoch should run on

    @return: Returns the accuracy over the full validation dataset as a float."""
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = outputs.max(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total if total > 0 else 0.0
    logger.info(f" ---> Eval Accuracy: {accuracy * 100:.2f}%")
    return float(accuracy)


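# CLASSES can translate predicted indices back into label names, e.g. (sketch):
#   _, preds = model(images).max(1)
#   names = [CLASSES[i] for i in preds.tolist()]

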
def main():
    """Exercise 2.3d"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_epochs = 10
    model = MmpNet(num_classes=10).to(device=device)
    dataloader_train = get_dataloader(True, "../../.data", 32, 6)
    dataloader_eval = get_dataloader(False, "../../.data", 32, 6)
    criterion, optimizer = get_criterion_optimizer(model=model)

    for epoch in range(train_epochs):
        log_epoch_progress(epoch, train_epochs, "start")
        train_epoch(
            model=model,
            loader=dataloader_train,
            optimizer=optimizer,
            device=device,
            criterion=criterion,
        )
        eval_epoch(model=model, loader=dataloader_eval, device=device)
        log_epoch_progress(epoch, train_epochs, "end")


if __name__ == "__main__":
    main()