Files
mmp_wise2526_franksim/mmp/a5/model.py
2025-11-16 16:28:13 +01:00

41 lines
1.3 KiB
Python

import torch
from torchvision import models
from torchvision.models import MobileNet_V2_Weights
from torch import nn
class MmpNet(torch.nn.Module):
    """MobileNetV2 feature extractor followed by a single convolutional head.

    The backbone is ``models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features``.
    A 3x3 convolution (stride 1, padding 1, so the backbone's spatial size is
    kept) maps the backbone features to
    ``num_widths * num_aspect_ratios * num_classes`` channels. ``forward``
    regroups these into one ``num_classes``-vector per (width, aspect ratio,
    feature-map cell). Presumably each (width, aspect ratio) pair describes an
    anchor-box shape; confirm against the anchor-generation code.

    Args:
        num_widths: Number of widths predicted per feature-map cell.
        num_aspect_ratios: Number of aspect ratios predicted per width.
        num_classes: Number of output scores per prediction (default 2).
    """

    def __init__(self, num_widths: int, num_aspect_ratios: int, num_classes: int = 2):
        super().__init__()
        self.backbone = models.mobilenet_v2(
            weights=MobileNet_V2_Weights.DEFAULT
        ).features
        self.num_widths = num_widths
        self.num_aspect_ratios = num_aspect_ratios
        self.num_classes = num_classes

        # Infer the backbone's output channel count instead of hard-coding it.
        in_channels = self._probe_out_channels(self.backbone)
        self.head = nn.Conv2d(
            in_channels=in_channels,
            out_channels=self.get_required_output_channels(),
            kernel_size=3,
            stride=1,
            padding=1,
        )

    @staticmethod
    def _probe_out_channels(backbone: nn.Module, input_size: int = 224) -> int:
        """Return the channel count of ``backbone``'s output for a dummy 3-channel image.

        The probe runs in eval mode. ``torch.no_grad()`` alone does not stop
        BatchNorm layers in training mode from updating their running
        mean/variance and ``num_batches_tracked``. Without eval mode, the
        all-zero dummy input would silently corrupt the pretrained statistics.
        The previous train/eval mode is restored afterwards, even if the
        forward pass raises.
        """
        was_training = backbone.training
        backbone.eval()
        try:
            with torch.no_grad():
                dummy = torch.zeros(1, 3, input_size, input_size)
                return backbone(dummy).shape[1]
        finally:
            backbone.train(was_training)

    def get_required_output_channels(self) -> int:
        """Return the head's output channel count: widths x aspect ratios x classes."""
        return self.num_widths * self.num_aspect_ratios * self.num_classes

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run backbone and head, then regroup the head channels per prediction.

        Args:
            x: Input image batch. Assumed ``(batch, 3, H, W)`` because the
                backbone is probed with 3-channel input; TODO confirm with callers.

        Returns:
            Tensor of shape ``(batch, num_widths, num_aspect_ratios, h, w, num_classes)``,
            where ``h, w`` is the spatial size of the backbone output. These are
            raw head outputs; no activation is applied here.
        """
        features = self.backbone(x)
        scores = self.head(features)
        batch, _, feat_h, feat_w = scores.shape
        # Head channel index k decomposes as k = (w_i * A + a_i) * C + c_i,
        # with A = num_aspect_ratios and C = num_classes.
        scores = scores.view(
            batch, self.num_widths, self.num_aspect_ratios, self.num_classes, feat_h, feat_w
        )
        # Move the class axis last: (batch, widths, aspect_ratios, h, w, classes).
        return scores.permute(0, 1, 2, 4, 5, 3).contiguous()