from typing import Tuple import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision import models, datasets, transforms from torchvision.models import MobileNet_V2_Weights import logging logging.basicConfig( level=logging.INFO, format="[%(asctime)s] %(levelname)s: %(message)s", datefmt="%H:%M:%S", ) logger = logging.getLogger(__name__) # these are the labels from the Cifar10 dataset: CLASSES = ( "plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck", ) class MmpNet(nn.Module): """Exercise 2.1""" def __init__(self, num_classes: int): super().__init__() self.mobilenet = models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT) self.classifier = nn.Sequential( nn.Dropout(0.2), nn.Linear(self.mobilenet.last_channel, num_classes), ) def forward(self, x: torch.Tensor): x = self.mobilenet.features(x) x = nn.functional.adaptive_avg_pool2d(x, (1, 1)) x = torch.flatten(x, 1) x = self.classifier(x) return x def get_dataloader( is_train: bool, data_root: str, batch_size: int, num_workers: int ) -> DataLoader: """Exercise 2.2 @param is_train: Whether this is the training or validation split @param data_root: Where to download the dataset to @param batch_size: Batch size for the data loader @param num_workers: Number of workers for the data loader """ transform = transforms.Compose( [ transforms.ToTensor(), transforms.Normalize( mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010] ), ] ) dataset = datasets.CIFAR10( root=data_root, train=is_train, download=True, transform=transform ) dataloader = DataLoader( dataset, batch_size=batch_size, shuffle=is_train, num_workers=num_workers, pin_memory=True, ) return dataloader def get_criterion_optimizer(model: nn.Module) -> Tuple[nn.Module, optim.Optimizer]: """Exercise 2.3a @param model: The model that is being trained. @return: Returns a tuple of the criterion and the optimizer. """ error_function = nn.CrossEntropyLoss() epsilon = 0.004 optimizer = torch.optim.SGD(model.parameters(), lr=epsilon) return (error_function, optimizer) def log_epoch_progress(epoch: int, total_epochs: int, phase: str): if phase == "start": logger.info(f"Epoch {epoch + 1}/{total_epochs} started.") elif phase == "end": logger.info(f"Epoch {epoch + 1}/{total_epochs} completed.") def train_epoch( model: nn.Module, loader: DataLoader, criterion: nn.Module, optimizer: optim.Optimizer, device: torch.device, ): """Exercise 2.3b @param model: The model that should be trained @param loader: The DataLoader that contains the training data @param criterion: The criterion that is used to calculate the loss for backpropagation @param optimizer: Executes the update step @param device: The device where the epoch should run on """ model.train() running_loss = 0.0 log_interval = max(len(loader) // 5, 1) for batch_idx, (inputs, labels) in enumerate(loader, 1): inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() * inputs.size(0) if batch_idx % log_interval == 0 or batch_idx == len(loader): avg_batch_loss = running_loss / (batch_idx * loader.batch_size) logger.info( f" [Batch {batch_idx}/{len(loader)}] Train Loss: {avg_batch_loss:.4f}" ) epoch_loss = running_loss / len(loader.dataset) logger.info(f" ---> Train Loss (Epoch): {epoch_loss:.4f}") return epoch_loss def eval_epoch(model: nn.Module, loader: DataLoader, device: torch.device) -> float: """Exercise 2.3c @param model: The model that should be evaluated @param loader: The DataLoader that contains the evaluation data @param device: The device where the epoch should run on @return: Returns the accuracy over the full validation dataset as a float.""" model.eval() correct = 0 total = 0 with torch.no_grad(): for inputs, labels in loader: inputs = inputs.to(device) labels = labels.to(device) outputs = model(inputs) _, preds = outputs.max(1) correct += (preds == labels).sum().item() total += labels.size(0) accuracy = correct / total if total > 0 else 0.0 logger.info(f" ---> Eval Accuracy: {accuracy * 100:.2f}%") return float(accuracy) def main(): """Exercise 2.3d""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") train_epochs = 10 model = MmpNet(num_classes=10).to(device=device) dataloader_train = get_dataloader(True, "../../.data", 32, 6) dataloader_eval = get_dataloader(False, "../../.data", 32, 6) criterion, optimizer = get_criterion_optimizer(model=model) for epoche in range(train_epochs): log_epoch_progress(epoche, train_epochs, "start") train_epoch( model=model, loader=dataloader_train, optimizer=optimizer, device=device, criterion=criterion, ) eval_epoch(model=model, loader=dataloader_eval, device=device) log_epoch_progress(epoche, train_epochs, "end") if __name__ == "__main__": main()