assignment-a3: use tensorboard for logging
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,5 +1,6 @@
|
|||||||
.venv/
|
.venv/
|
||||||
.data
|
.data
|
||||||
|
runs/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.code-workspace
|
*.code-workspace
|
||||||
.vscode/
|
.vscode/
|
||||||
|
|||||||
@@ -1,35 +1,57 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from a2.main import MmpNet, get_criterion_optimizer, log_epoch_progress, train_epoch, eval_epoch
|
import argparse
|
||||||
|
from a2.main import MmpNet, get_criterion_optimizer, train_epoch, eval_epoch
|
||||||
from a3.dataset import get_dataloader
|
from a3.dataset import get_dataloader
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Put your code for Exercise 3.3 in here"""
|
"""Put your code for Exercise 3.3 in here"""
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--tensorboard', action='store_true',
|
||||||
|
help='Enable TensorBoard logging')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
train_epochs = 10
|
train_epochs = 10
|
||||||
model = MmpNet(num_classes=10).to(device=device)
|
model = MmpNet(num_classes=2).to(device=device)
|
||||||
dataloader_train = get_dataloader(path_to_data="/home/ubuntu/mmp_wise2526_franksim/.data/mmp-public-3.2",
|
dataloader_train = get_dataloader(
|
||||||
image_size=244, batch_size=32, num_workers=6, is_train=True)
|
path_to_data="/home/ubuntu/mmp_wise2526_franksim/.data/mmp-public-3.2",
|
||||||
dataloader_eval = get_dataloader(path_to_data="/home/ubuntu/mmp_wise2526_franksim/.data/mmp-public-3.2",
|
image_size=244, batch_size=32, num_workers=6, is_train=True
|
||||||
image_size=244, batch_size=32, num_workers=6, is_train=False)
|
)
|
||||||
|
dataloader_eval = get_dataloader(
|
||||||
|
path_to_data="/home/ubuntu/mmp_wise2526_franksim/.data/mmp-public-3.2",
|
||||||
|
image_size=244, batch_size=32, num_workers=6, is_train=False
|
||||||
|
)
|
||||||
criterion, optimizer = get_criterion_optimizer(model=model)
|
criterion, optimizer = get_criterion_optimizer(model=model)
|
||||||
|
|
||||||
for epoche in range(train_epochs):
|
writer = None
|
||||||
log_epoch_progress(epoche, train_epochs, "start")
|
if args.tensorboard:
|
||||||
train_epoch(
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
writer = SummaryWriter(log_dir="runs/a3_mmpnet")
|
||||||
|
|
||||||
|
for epoch in range(train_epochs):
|
||||||
|
train_loss = train_epoch(
|
||||||
model=model,
|
model=model,
|
||||||
loader=dataloader_train,
|
loader=dataloader_train,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
device=device,
|
device=device,
|
||||||
criterion=criterion,
|
criterion=criterion,
|
||||||
)
|
)
|
||||||
eval_epoch(
|
val_acc = eval_epoch(
|
||||||
model=model,
|
model=model,
|
||||||
loader=dataloader_eval,
|
loader=dataloader_eval,
|
||||||
device=device
|
device=device
|
||||||
)
|
)
|
||||||
log_epoch_progress(epoche, train_epochs, "end")
|
|
||||||
|
print(
|
||||||
|
f"Epoch [{epoch+1}/{train_epochs}] - Train Loss: {train_loss:.4f} - Val Acc: {val_acc:.4f}")
|
||||||
|
|
||||||
|
if writer is not None:
|
||||||
|
writer.add_scalar("Loss/train", train_loss, epoch)
|
||||||
|
writer.add_scalar("Accuracy/val", val_acc, epoch)
|
||||||
|
|
||||||
|
if writer is not None:
|
||||||
|
writer.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user