formatting
@@ -9,8 +9,8 @@ import logging
 
 logging.basicConfig(
     level=logging.INFO,
-    format='[%(asctime)s] %(levelname)s: %(message)s',
-    datefmt='%H:%M:%S'
+    format="[%(asctime)s] %(levelname)s: %(message)s",
+    datefmt="%H:%M:%S",
 )
 logger = logging.getLogger(__name__)
 
@@ -34,8 +34,7 @@ class MmpNet(nn.Module):
 
     def __init__(self, num_classes: int):
         super().__init__()
-        self.mobilenet = models.mobilenet_v2(
-            weights=MobileNet_V2_Weights.DEFAULT)
+        self.mobilenet = models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
         self.classifier = nn.Sequential(
             nn.Dropout(0.2),
             nn.Linear(self.mobilenet.last_channel, num_classes),
@@ -59,24 +58,23 @@ def get_dataloader(
     @param batch_size: Batch size for the data loader
     @param num_workers: Number of workers for the data loader
     """
-    transform = transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize(
-            mean=[0.4914, 0.4822, 0.4465],
-            std=[0.2023, 0.1994, 0.2010]
-        ),
-    ])
+    transform = transforms.Compose(
+        [
+            transforms.ToTensor(),
+            transforms.Normalize(
+                mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]
+            ),
+        ]
+    )
     dataset = datasets.CIFAR10(
-        root=data_root,
-        train=is_train,
-        download=True,
-        transform=transform
+        root=data_root, train=is_train, download=True, transform=transform
     )
     dataloader = DataLoader(
-        dataset, batch_size=batch_size,
+        dataset,
+        batch_size=batch_size,
         shuffle=is_train,
         num_workers=num_workers,
-        pin_memory=True
+        pin_memory=True,
     )
     return dataloader
 
@@ -133,7 +131,8 @@ def train_epoch(
         if batch_idx % log_interval == 0 or batch_idx == len(loader):
             avg_batch_loss = running_loss / (batch_idx * loader.batch_size)
             logger.info(
-                f" [Batch {batch_idx}/{len(loader)}] Train Loss: {avg_batch_loss:.4f}")
+                f" [Batch {batch_idx}/{len(loader)}] Train Loss: {avg_batch_loss:.4f}"
+            )
 
     epoch_loss = running_loss / len(loader.dataset)
     logger.info(f" ---> Train Loss (Epoch): {epoch_loss:.4f}")
@@ -184,11 +183,7 @@ def main():
             device=device,
             criterion=criterion,
         )
-        eval_epoch(
-            model=model,
-            loader=dataloader_eval,
-            device=device
-        )
+        eval_epoch(model=model, loader=dataloader_eval, device=device)
         log_epoch_progress(epoche, train_epochs, "end")
 
 
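
Note: the pattern of changes above (single quotes normalized to double quotes, trailing commas added, and multi-line call sites rewrapped to a consistent width) is consistent with running an auto-formatter such as Black over the module. That is an assumption; the commit message only says "formatting" and does not name a tool. A minimal sketch of the quote normalization applied to the datefmt line from the first hunk, assuming Black is installed (pip install black):

# Illustrative only -- not the project's confirmed tooling.
import black

source = "datefmt = '%H:%M:%S'\n"
formatted = black.format_str(source, mode=black.Mode())
print(formatted)  # prints: datefmt = "%H:%M:%S"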