import torch

from a2.main import MmpNet, get_criterion_optimizer, log_epoch_progress, train_epoch, eval_epoch
from a3.dataset import get_dataloader

# Default dataset location. Kept identical to the previously hard-coded path so
# that calling ``main()`` without arguments behaves exactly as before.
DEFAULT_DATA_PATH = "/home/ubuntu/mmp_wise2526_franksim/.data/mmp-public-3.2"


def main(
    path_to_data: str = DEFAULT_DATA_PATH,
    *,
    train_epochs: int = 10,
    num_classes: int = 10,
    image_size: int = 244,
    batch_size: int = 32,
    num_workers: int = 6,
) -> None:
    """Exercise 3.3: train ``MmpNet`` and evaluate it after every epoch.

    Builds the model on CUDA if available (CPU otherwise), creates a training
    and an evaluation dataloader from the same dataset directory, and then runs
    ``train_epochs`` rounds of one training epoch followed by one evaluation
    epoch. Progress is reported via ``log_epoch_progress`` at the start and end
    of each epoch.

    Args:
        path_to_data: Dataset directory passed to ``get_dataloader`` for both
            the training and the evaluation split.
        train_epochs: Number of train/eval rounds to run.
        num_classes: Passed to ``MmpNet`` as its output class count.
            NOTE(review): confirm this matches the label set of the dataset.
        image_size: Passed to ``get_dataloader`` for both splits.
            NOTE(review): 244 may be a typo for the common 224 input size;
            confirm against MmpNet / the exercise sheet.
        batch_size: Batch size for both dataloaders.
        num_workers: Worker count for both dataloaders.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = MmpNet(num_classes=num_classes).to(device=device)

    # Both splits share every loader setting except ``is_train``; keeping them
    # in one place prevents the two calls from drifting apart.
    loader_kwargs = dict(
        path_to_data=path_to_data,
        image_size=image_size,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    dataloader_train = get_dataloader(**loader_kwargs, is_train=True)
    dataloader_eval = get_dataloader(**loader_kwargs, is_train=False)

    criterion, optimizer = get_criterion_optimizer(model=model)

    for epoch in range(train_epochs):
        log_epoch_progress(epoch, train_epochs, "start")
        train_epoch(
            model=model,
            loader=dataloader_train,
            optimizer=optimizer,
            device=device,
            criterion=criterion,
        )
        # The return value of eval_epoch is not used here; any reporting is
        # left to eval_epoch itself.
        eval_epoch(
            model=model,
            loader=dataloader_eval,
            device=device,
        )
        log_epoch_progress(epoch, train_epochs, "end")


if __name__ == "__main__":
    main()