2025-10-13 14:48:00 +02:00
|
|
|
import torch
|
2025-11-16 16:28:13 +01:00
|
|
|
from torchvision import models
|
|
|
|
|
from torchvision.models import MobileNet_V2_Weights
|
|
|
|
|
from torch import nn
|
2025-10-13 14:48:00 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class MmpNet(torch.nn.Module):
    """MobileNetV2 feature extractor followed by a single 3x3 conv scoring head.

    For every spatial position of the backbone's feature map, the head emits
    ``num_widths * num_aspect_ratios * num_classes`` channels. ``forward``
    regroups them into one ``num_classes``-long score vector per
    (width, aspect ratio, position) combination. Presumably these index anchor
    boxes; that interpretation comes from the parameter names only.
    """

    def __init__(
        self,
        num_widths: int,
        num_aspect_ratios: int,
        num_classes: int = 2,
        *,
        weights: object = "DEFAULT",
    ):
        """Build the backbone and the scoring head.

        Args:
            num_widths: Number of width variants per position (must be >= 1).
            num_aspect_ratios: Number of aspect-ratio variants per position
                (must be >= 1).
            num_classes: Number of scores per variant (must be >= 1).
            weights: Weights passed to ``torchvision.models.mobilenet_v2``.
                The default ``"DEFAULT"`` selects
                ``MobileNet_V2_Weights.DEFAULT``, matching the previous
                hard-coded behaviour. ``None`` builds an untrained backbone.

        Raises:
            ValueError: If any of the three counts is smaller than 1.
        """
        super().__init__()
        # Validate before touching torchvision so bad arguments fail fast
        # and without downloading weights.
        for arg_name, value in (
            ("num_widths", num_widths),
            ("num_aspect_ratios", num_aspect_ratios),
            ("num_classes", num_classes),
        ):
            if value < 1:
                raise ValueError(f"{arg_name} must be a positive integer, got {value!r}")

        if weights == "DEFAULT":
            weights = MobileNet_V2_Weights.DEFAULT
        self.backbone = models.mobilenet_v2(weights=weights).features

        self.num_widths = num_widths
        self.num_aspect_ratios = num_aspect_ratios
        self.num_classes = num_classes

        in_channels = self._probe_out_channels(self.backbone)
        self.head = nn.Conv2d(
            in_channels=in_channels,
            out_channels=self.get_required_output_channels(),
            kernel_size=3,
            stride=1,
            padding=1,
        )

    @staticmethod
    def _probe_out_channels(module: nn.Module, input_size: tuple = (1, 3, 224, 224)) -> int:
        """Return the channel dimension of ``module``'s output for a zeros input.

        The probe runs in eval mode. In train mode, BatchNorm layers would
        update their running statistics even under ``torch.no_grad()``,
        corrupting pretrained weights. Every submodule's ``training`` flag is
        restored exactly afterwards, even if the forward pass raises.
        """
        saved_modes = [(sub, sub.training) for sub in module.modules()]
        module.eval()
        try:
            with torch.no_grad():
                probe_out = module(torch.zeros(input_size))
        finally:
            for sub, was_training in saved_modes:
                sub.training = was_training
        return int(probe_out.shape[1])

    def get_required_output_channels(self):
        """Return the number of output channels the head must produce."""
        return self.num_widths * self.num_aspect_ratios * self.num_classes

    def forward(self, x: torch.Tensor):
        """Score every (width, aspect ratio, position) combination.

        Args:
            x: Image batch of shape (batch, 3, H, W); 3 channels as used by
                the construction-time probe.

        Returns:
            Tensor of shape
            (batch, num_widths, num_aspect_ratios, h, w, num_classes), where
            (h, w) is the spatial size of the backbone feature map.
        """
        features = self.backbone(x)
        scores = self.head(features)
        batch, _, h, w = scores.shape
        # The channel axis is laid out as (width, aspect ratio, class).
        scores = scores.view(batch, self.num_widths, self.num_aspect_ratios, self.num_classes, h, w)
        # Move the class axis last:
        # (batch, widths, aspect_ratios, h, w, classes).
        return scores.permute(0, 1, 2, 4, 5, 3).contiguous()
|