formatting
This commit is contained in:
@@ -7,8 +7,9 @@ from .dataset import get_dataloader
|
||||
def main():
|
||||
"""Put your code for Exercise 3.3 in here"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--tensorboard', action='store_true',
|
||||
help='Enable TensorBoard logging')
|
||||
parser.add_argument(
|
||||
"--tensorboard", action="store_true", help="Enable TensorBoard logging"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
@@ -16,17 +17,24 @@ def main():
|
||||
model = MmpNet(num_classes=2).to(device=device)
|
||||
dataloader_train = get_dataloader(
|
||||
path_to_data=".data/mmp-public-3.2",
|
||||
image_size=244, batch_size=32, num_workers=6, is_train=True
|
||||
image_size=244,
|
||||
batch_size=32,
|
||||
num_workers=6,
|
||||
is_train=True,
|
||||
)
|
||||
dataloader_eval = get_dataloader(
|
||||
path_to_data=".data/mmp-public-3.2",
|
||||
image_size=244, batch_size=32, num_workers=6, is_train=False
|
||||
image_size=244,
|
||||
batch_size=32,
|
||||
num_workers=6,
|
||||
is_train=False,
|
||||
)
|
||||
criterion, optimizer = get_criterion_optimizer(model=model)
|
||||
|
||||
writer = None
|
||||
if args.tensorboard:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
writer = SummaryWriter(log_dir="runs/a3_mmpnet")
|
||||
|
||||
for epoch in range(train_epochs):
|
||||
@@ -37,14 +45,11 @@ def main():
|
||||
device=device,
|
||||
criterion=criterion,
|
||||
)
|
||||
val_acc = eval_epoch(
|
||||
model=model,
|
||||
loader=dataloader_eval,
|
||||
device=device
|
||||
)
|
||||
val_acc = eval_epoch(model=model, loader=dataloader_eval, device=device)
|
||||
|
||||
print(
|
||||
f"Epoch [{epoch+1}/{train_epochs}] - Train Loss: {train_loss:.4f} - Val Acc: {val_acc:.4f}")
|
||||
f"Epoch [{epoch + 1}/{train_epochs}] - Train Loss: {train_loss:.4f} - Val Acc: {val_acc:.4f}"
|
||||
)
|
||||
|
||||
if writer is not None:
|
||||
writer.add_scalar("Loss/train", train_loss, epoch)
|
||||
|
||||
Reference in New Issue
Block a user