diff --git a/mmp/a2/main.py b/mmp/a2/main.py
index 0102036..49451db 100644
--- a/mmp/a2/main.py
+++ b/mmp/a2/main.py
@@ -3,6 +3,16 @@ import torch
 import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader
+from torchvision import models, datasets, transforms
+from torchvision.models import MobileNet_V2_Weights
+import logging
+
+logging.basicConfig(
+    level=logging.INFO,
+    format='[%(asctime)s] %(levelname)s: %(message)s',
+    datefmt='%H:%M:%S'
+)
+logger = logging.getLogger(__name__)
 
 # these are the labels from the Cifar10 dataset:
 CLASSES = (
@@ -23,10 +33,20 @@ class MmpNet(nn.Module):
     """Exercise 2.1"""
 
     def __init__(self, num_classes: int):
-        raise NotImplementedError()
+        super().__init__()
+        self.mobilenet = models.mobilenet_v2(
+            weights=MobileNet_V2_Weights.DEFAULT)
+        self.classifier = nn.Sequential(
+            nn.Dropout(0.2),
+            nn.Linear(self.mobilenet.last_channel, num_classes),
+        )
 
     def forward(self, x: torch.Tensor):
-        raise NotImplementedError()
+        x = self.mobilenet.features(x)
+        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
+        x = torch.flatten(x, 1)
+        x = self.classifier(x)
+        return x
 
 
 def get_dataloader(
@@ -39,7 +59,26 @@ def get_dataloader(
     @param batch_size: Batch size for the data loader
     @param num_workers: Number of workers for the data loader
     """
-    raise NotImplementedError()
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(
+            mean=[0.4914, 0.4822, 0.4465],
+            std=[0.2023, 0.1994, 0.2010]
+        ),
+    ])
+    dataset = datasets.CIFAR10(
+        root=data_root,
+        train=is_train,
+        download=True,
+        transform=transform
+    )
+    dataloader = DataLoader(
+        dataset, batch_size=batch_size,
+        shuffle=is_train,
+        num_workers=num_workers,
+        pin_memory=True
+    )
+    return dataloader
 
 
 def get_criterion_optimizer(model: nn.Module) -> Tuple[nn.Module, optim.Optimizer]:
@@ -48,7 +87,17 @@ def get_criterion_optimizer(model: nn.Module) -> Tuple[nn.Module, optim.Optimize
     @param model: The model that is being trained.
     @return: Returns a tuple of the criterion and the optimizer.
     """
-    raise NotImplementedError()
+    criterion = nn.CrossEntropyLoss()
+    lr = 0.004
+    optimizer = optim.SGD(model.parameters(), lr=lr)
+    return criterion, optimizer
+
+
+def log_epoch_progress(epoch: int, total_epochs: int, phase: str):
+    if phase == "start":
+        logger.info(f"Epoch {epoch + 1}/{total_epochs} started.")
+    elif phase == "end":
+        logger.info(f"Epoch {epoch + 1}/{total_epochs} completed.")
 
 
 def train_epoch(
@@ -66,7 +115,29 @@ def train_epoch(
     @param optimizer: Executes the update step
     @param device: The device where the epoch should run on
     """
-    raise NotImplementedError()
+    model.train()
+    running_loss = 0.0
+    log_interval = max(len(loader) // 5, 1)
+
+    for batch_idx, (inputs, labels) in enumerate(loader, 1):
+        inputs = inputs.to(device)
+        labels = labels.to(device)
+
+        optimizer.zero_grad()
+        outputs = model(inputs)
+        loss = criterion(outputs, labels)
+        loss.backward()
+        optimizer.step()
+        running_loss += loss.item() * inputs.size(0)
+
+        if batch_idx % log_interval == 0 or batch_idx == len(loader):
+            avg_batch_loss = running_loss / min(batch_idx * loader.batch_size, len(loader.dataset))
+            logger.info(
+                f" [Batch {batch_idx}/{len(loader)}] Train Loss: {avg_batch_loss:.4f}")
+
+    epoch_loss = running_loss / len(loader.dataset)
+    logger.info(f" ---> Train Loss (Epoch): {epoch_loss:.4f}")
+    return epoch_loss
 
 
 def eval_epoch(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
@@ -77,12 +148,48 @@ def eval_epoch(model: nn.Module, loader: DataLoader, device: torch.device) -> fl
     @param device: The device where the epoch should run on
     @return: Returns the accuracy over the full validation dataset as a float."""
-    raise NotImplementedError()
+    model.eval()
+    correct = 0
+    total = 0
+
+    with torch.no_grad():
+        for inputs, labels in loader:
+            inputs = inputs.to(device)
+            labels = labels.to(device)
+            outputs = model(inputs)
+            _, preds = outputs.max(1)
+            correct += (preds == labels).sum().item()
+            total += labels.size(0)
+
+    accuracy = correct / total if total > 0 else 0.0
+    logger.info(f" ---> Eval Accuracy: {accuracy * 100:.2f}%")
+    return float(accuracy)
 
 
 def main():
     """Exercise 2.3d"""
-    raise NotImplementedError()
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    train_epochs = 10
+    model = MmpNet(num_classes=10).to(device=device)
+    dataloader_train = get_dataloader(True, "../../.data", 32, 6)
+    dataloader_eval = get_dataloader(False, "../../.data", 32, 6)
+    criterion, optimizer = get_criterion_optimizer(model=model)
+
+    for epoch in range(train_epochs):
+        log_epoch_progress(epoch, train_epochs, "start")
+        train_epoch(
+            model=model,
+            loader=dataloader_train,
+            optimizer=optimizer,
+            device=device,
+            criterion=criterion,
+        )
+        eval_epoch(
+            model=model,
+            loader=dataloader_eval,
+            device=device
+        )
+        log_epoch_progress(epoch, train_epochs, "end")
 
 
 if __name__ == "__main__":
     main()