update template code

This commit is contained in:
phatakmr
2025-10-13 14:48:00 +02:00
parent c9d159fcc6
commit 8f637a4a0d
46 changed files with 2955 additions and 1 deletions

89
mmp/a2/main.py Normal file
View File

@@ -0,0 +1,89 @@
from typing import Tuple
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
# these are the labels from the Cifar10 dataset:
CLASSES = (
"plane",
"car",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck",
)
class MmpNet(nn.Module):
"""Exercise 2.1"""
def __init__(self, num_classes: int):
raise NotImplementedError()
def forward(self, x: torch.Tensor):
raise NotImplementedError()
def get_dataloader(
is_train: bool, data_root: str, batch_size: int, num_workers: int
) -> DataLoader:
"""Exercise 2.2
@param is_train: Whether this is the training or validation split
@param data_root: Where to download the dataset to
@param batch_size: Batch size for the data loader
@param num_workers: Number of workers for the data loader
"""
raise NotImplementedError()
def get_criterion_optimizer(model: nn.Module) -> Tuple[nn.Module, optim.Optimizer]:
"""Exercise 2.3a
@param model: The model that is being trained.
@return: Returns a tuple of the criterion and the optimizer.
"""
raise NotImplementedError()
def train_epoch(
model: nn.Module,
loader: DataLoader,
criterion: nn.Module,
optimizer: optim.Optimizer,
device: torch.device,
):
"""Exercise 2.3b
@param model: The model that should be trained
@param loader: The DataLoader that contains the training data
@param criterion: The criterion that is used to calculate the loss for backpropagation
@param optimizer: Executes the update step
@param device: The device where the epoch should run on
"""
raise NotImplementedError()
def eval_epoch(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
"""Exercise 2.3c
@param model: The model that should be evaluated
@param loader: The DataLoader that contains the evaluation data
@param device: The device where the epoch should run on
@return: Returns the accuracy over the full validation dataset as a float."""
raise NotImplementedError()
def main():
"""Exercise 2.3d"""
raise NotImplementedError()
if __name__ == "__main__":
main()