2025-10-28 16:03:53 +00:00
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from a2.main import MmpNet, get_criterion_optimizer, log_epoch_progress, train_epoch, eval_epoch
|
|
|
|
|
from a3.dataset import get_dataloader
|
|
|
|
|
|
|
|
|
|
|
2025-10-13 14:48:00 +02:00
|
|
|
def main(
    *,
    path_to_data: str = "/home/ubuntu/mmp_wise2526_franksim/.data/mmp-public-3.2",
    train_epochs: int = 10,
    image_size: int = 244,
    batch_size: int = 32,
    num_workers: int = 6,
    num_classes: int = 10,
) -> None:
    """Train and evaluate an ``MmpNet`` for Exercise 3.3.

    Builds the model, a training and an evaluation dataloader, and the
    criterion/optimizer pair, then for each of ``train_epochs`` rounds calls
    ``train_epoch`` followed by ``eval_epoch``, bracketed by
    ``log_epoch_progress`` calls. Every keyword argument defaults to the value
    this script originally hard-coded, so ``main()`` with no arguments behaves
    exactly as before.

    Args:
        path_to_data: Dataset root, shared by the train and eval loaders.
        train_epochs: Number of train/eval rounds to run.
        image_size: Image size passed to ``get_dataloader`` for both loaders.
        batch_size: Batch size for both loaders.
        num_workers: Worker count for both loaders.
        num_classes: Number of output classes given to ``MmpNet``.
    """
    # Prefer the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = MmpNet(num_classes=num_classes).to(device=device)

    # NOTE(review): 244 (the default image_size) is an unusual value; it may be
    # a typo for the common 224. Confirm against what MmpNet / the dataset expect.
    dataloader_train = get_dataloader(
        path_to_data=path_to_data,
        image_size=image_size,
        batch_size=batch_size,
        num_workers=num_workers,
        is_train=True,
    )
    dataloader_eval = get_dataloader(
        path_to_data=path_to_data,
        image_size=image_size,
        batch_size=batch_size,
        num_workers=num_workers,
        is_train=False,
    )

    # get_criterion_optimizer returns a (criterion, optimizer) pair for `model`.
    criterion, optimizer = get_criterion_optimizer(model=model)

    for epoch in range(train_epochs):
        log_epoch_progress(epoch, train_epochs, "start")
        train_epoch(
            model=model,
            loader=dataloader_train,
            optimizer=optimizer,
            device=device,
            criterion=criterion,
        )
        # NOTE(review): the return value of eval_epoch (if any) is discarded,
        # as in the original script.
        eval_epoch(
            model=model,
            loader=dataloader_eval,
            device=device,
        )
        log_epoch_progress(epoch, train_epochs, "end")
|
2025-10-13 14:48:00 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Start training only when this file is executed as a script,
    # not when it is imported as a module.
    main()
|