Files
mmp_wise2526_franksim/mmp/a3/main.py
2025-10-28 16:03:53 +00:00

37 lines
1.3 KiB
Python

import torch
from a2.main import MmpNet, get_criterion_optimizer, log_epoch_progress, train_epoch, eval_epoch
from a3.dataset import get_dataloader
def main(
    *,
    path_to_data="/home/ubuntu/mmp_wise2526_franksim/.data/mmp-public-3.2",
    train_epochs=10,
    image_size=244,
    batch_size=32,
    num_workers=6,
    num_classes=10,
):
    """Train and evaluate ``MmpNet`` for Exercise 3.3.

    Builds a training and an evaluation dataloader over the same data
    directory, then runs ``train_epochs`` epochs. Each epoch trains on the
    training loader and then evaluates on the evaluation loader. Called with
    no arguments, every default reproduces the original hard-coded setup.

    Args:
        path_to_data: Dataset root directory, forwarded to ``get_dataloader``.
        train_epochs: Number of train-then-evaluate epochs to run.
        image_size: Image size forwarded to ``get_dataloader``.
        batch_size: Batch size forwarded to ``get_dataloader``.
        num_workers: Worker count forwarded to ``get_dataloader``.
        num_classes: Number of output classes for ``MmpNet``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MmpNet(num_classes=num_classes).to(device=device)

    # Both loaders share every setting except the split flag, so build the
    # common keyword arguments once instead of duplicating them.
    # NOTE(review): the default image_size of 244 may be a typo for the more
    # common 224 -- confirm against the dataset / model expectations.
    loader_kwargs = dict(
        path_to_data=path_to_data,
        image_size=image_size,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    dataloader_train = get_dataloader(**loader_kwargs, is_train=True)
    dataloader_eval = get_dataloader(**loader_kwargs, is_train=False)

    criterion, optimizer = get_criterion_optimizer(model=model)
    for epoch in range(train_epochs):
        log_epoch_progress(epoch, train_epochs, "start")
        train_epoch(
            model=model,
            loader=dataloader_train,
            optimizer=optimizer,
            device=device,
            criterion=criterion,
        )
        # Any value returned by eval_epoch is not used here.
        eval_epoch(
            model=model,
            loader=dataloader_eval,
            device=device,
        )
        log_epoch_progress(epoch, train_epochs, "end")
if __name__ == "__main__":
    # Run the training/evaluation script only when executed directly,
    # not when this module is imported.
    main()