update template code
This commit is contained in:
89
mmp/a2/main.py
Normal file
89
mmp/a2/main.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# these are the labels from the Cifar10 dataset:
|
||||
CLASSES = (
|
||||
"plane",
|
||||
"car",
|
||||
"bird",
|
||||
"cat",
|
||||
"deer",
|
||||
"dog",
|
||||
"frog",
|
||||
"horse",
|
||||
"ship",
|
||||
"truck",
|
||||
)
|
||||
|
||||
|
||||
class MmpNet(nn.Module):
|
||||
"""Exercise 2.1"""
|
||||
|
||||
def __init__(self, num_classes: int):
|
||||
raise NotImplementedError()
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def get_dataloader(
|
||||
is_train: bool, data_root: str, batch_size: int, num_workers: int
|
||||
) -> DataLoader:
|
||||
"""Exercise 2.2
|
||||
|
||||
@param is_train: Whether this is the training or validation split
|
||||
@param data_root: Where to download the dataset to
|
||||
@param batch_size: Batch size for the data loader
|
||||
@param num_workers: Number of workers for the data loader
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def get_criterion_optimizer(model: nn.Module) -> Tuple[nn.Module, optim.Optimizer]:
|
||||
"""Exercise 2.3a
|
||||
|
||||
@param model: The model that is being trained.
|
||||
@return: Returns a tuple of the criterion and the optimizer.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def train_epoch(
|
||||
model: nn.Module,
|
||||
loader: DataLoader,
|
||||
criterion: nn.Module,
|
||||
optimizer: optim.Optimizer,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Exercise 2.3b
|
||||
|
||||
@param model: The model that should be trained
|
||||
@param loader: The DataLoader that contains the training data
|
||||
@param criterion: The criterion that is used to calculate the loss for backpropagation
|
||||
@param optimizer: Executes the update step
|
||||
@param device: The device where the epoch should run on
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def eval_epoch(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
|
||||
"""Exercise 2.3c
|
||||
|
||||
@param model: The model that should be evaluated
|
||||
@param loader: The DataLoader that contains the evaluation data
|
||||
@param device: The device where the epoch should run on
|
||||
|
||||
@return: Returns the accuracy over the full validation dataset as a float."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def main():
|
||||
"""Exercise 2.3d"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user