formatting

This commit is contained in:
franksim
2025-11-07 11:20:08 +01:00
parent 8fc3559d6c
commit b159d76517
8 changed files with 119 additions and 82 deletions

View File

@@ -9,8 +9,8 @@ import logging
logging.basicConfig(
level=logging.INFO,
format='[%(asctime)s] %(levelname)s: %(message)s',
datefmt='%H:%M:%S'
format="[%(asctime)s] %(levelname)s: %(message)s",
datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)
@@ -34,8 +34,7 @@ class MmpNet(nn.Module):
def __init__(self, num_classes: int):
super().__init__()
self.mobilenet = models.mobilenet_v2(
weights=MobileNet_V2_Weights.DEFAULT)
self.mobilenet = models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
self.classifier = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(self.mobilenet.last_channel, num_classes),
@@ -59,24 +58,23 @@ def get_dataloader(
@param batch_size: Batch size for the data loader
@param num_workers: Number of workers for the data loader
"""
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
mean=[0.4914, 0.4822, 0.4465],
std=[0.2023, 0.1994, 0.2010]
),
])
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(
mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]
),
]
)
dataset = datasets.CIFAR10(
root=data_root,
train=is_train,
download=True,
transform=transform
root=data_root, train=is_train, download=True, transform=transform
)
dataloader = DataLoader(
dataset, batch_size=batch_size,
dataset,
batch_size=batch_size,
shuffle=is_train,
num_workers=num_workers,
pin_memory=True
pin_memory=True,
)
return dataloader
@@ -133,7 +131,8 @@ def train_epoch(
if batch_idx % log_interval == 0 or batch_idx == len(loader):
avg_batch_loss = running_loss / (batch_idx * loader.batch_size)
logger.info(
f" [Batch {batch_idx}/{len(loader)}] Train Loss: {avg_batch_loss:.4f}")
f" [Batch {batch_idx}/{len(loader)}] Train Loss: {avg_batch_loss:.4f}"
)
epoch_loss = running_loss / len(loader.dataset)
logger.info(f" ---> Train Loss (Epoch): {epoch_loss:.4f}")
@@ -184,11 +183,7 @@ def main():
device=device,
criterion=criterion,
)
eval_epoch(
model=model,
loader=dataloader_eval,
device=device
)
eval_epoch(model=model, loader=dataloader_eval, device=device)
log_epoch_progress(epoche, train_epochs, "end")