2025-10-28 16:03:53 +00:00
|
|
|
import torch
|
2025-10-31 13:17:46 +00:00
|
|
|
import argparse
|
2025-10-31 14:02:28 +00:00
|
|
|
from ..a2.main import MmpNet, get_criterion_optimizer, train_epoch, eval_epoch
|
|
|
|
|
from .dataset import get_dataloader
|
2025-10-28 16:03:53 +00:00
|
|
|
|
|
|
|
|
|
2025-10-13 14:48:00 +02:00
|
|
|
def main():
    """Train ``MmpNet`` on the MMP 3.2 dataset (Exercise 3.3).

    Builds a training and an evaluation dataloader with ``get_dataloader``,
    obtains the loss criterion and optimizer from ``get_criterion_optimizer``,
    and runs ``train_epoch`` followed by ``eval_epoch`` once per epoch. After
    every epoch the training loss and validation accuracy are printed; with
    ``--tensorboard`` the same two scalars are also written to a TensorBoard
    log directory.

    All hyperparameters can be overridden on the command line. The defaults
    reproduce the values this script previously hard-coded, so running it
    without arguments behaves exactly as before.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tensorboard", action="store_true", help="Enable TensorBoard logging"
    )
    parser.add_argument(
        "--epochs", type=int, default=10, help="Number of training epochs"
    )
    parser.add_argument(
        "--data-path",
        default=".data/mmp-public-3.2",
        help="Dataset root passed to get_dataloader",
    )
    # NOTE(review): 244 may be a typo for the conventional 224 input size;
    # confirm against what MmpNet / get_dataloader expect before changing.
    parser.add_argument(
        "--image-size", type=int, default=244, help="Image size for the loaders"
    )
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size")
    parser.add_argument(
        "--num-workers", type=int, default=6, help="DataLoader worker processes"
    )
    parser.add_argument(
        "--log-dir",
        default="runs/a3_mmpnet",
        help="TensorBoard log directory (used only with --tensorboard)",
    )
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_epochs = args.epochs

    model = MmpNet(num_classes=2).to(device=device)

    # Both loaders share every setting except the train/eval switch.
    loader_kwargs = dict(
        path_to_data=args.data_path,
        image_size=args.image_size,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )
    dataloader_train = get_dataloader(is_train=True, **loader_kwargs)
    dataloader_eval = get_dataloader(is_train=False, **loader_kwargs)

    criterion, optimizer = get_criterion_optimizer(model=model)

    writer = None
    if args.tensorboard:
        # Imported here so TensorBoard is only needed when logging is requested.
        from torch.utils.tensorboard import SummaryWriter

        writer = SummaryWriter(log_dir=args.log_dir)

    try:
        for epoch in range(train_epochs):
            train_loss = train_epoch(
                model=model,
                loader=dataloader_train,
                optimizer=optimizer,
                device=device,
                criterion=criterion,
            )
            val_acc = eval_epoch(model=model, loader=dataloader_eval, device=device)

            print(
                f"Epoch [{epoch + 1}/{train_epochs}] - Train Loss: {train_loss:.4f} - Val Acc: {val_acc:.4f}"
            )

            if writer is not None:
                writer.add_scalar("Loss/train", train_loss, epoch)
                writer.add_scalar("Accuracy/val", val_acc, epoch)
    finally:
        # Close the writer even if training raises or is interrupted, so the
        # scalars logged so far are flushed to disk instead of being lost.
        if writer is not None:
            writer.close()
|
2025-10-13 14:48:00 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run training only when executed as a script, never as an import side effect.
    main()
|