import torch from torchvision import models from torchvision.models import MobileNet_V2_Weights from torch import nn class MmpNet(torch.nn.Module): def __init__(self, num_widths: int, num_aspect_ratios: int, num_classes: int = 2): super().__init__() self.backbone = models.mobilenet_v2( weights=MobileNet_V2_Weights.DEFAULT ).features self.num_widths = num_widths self.num_aspect_ratios = num_aspect_ratios self.num_classes = num_classes with torch.no_grad(): dummy = torch.zeros(1, 3, 224, 224) backbone_out = self.backbone(dummy) in_channels = backbone_out.shape[1] self.head = nn.Conv2d( in_channels=in_channels, kernel_size=3, out_channels=self.get_required_output_channels(), stride=1, padding=1, ) def get_required_output_channels(self): return self.num_widths * self.num_aspect_ratios * self.num_classes def forward(self, x: torch.Tensor): x = self.backbone(x) x = self.head(x) b, out_c, h, w = x.shape x = x.view(b, self.num_widths, self.num_aspect_ratios, self.num_classes, h, w) x = x.permute(0, 1, 2, 4, 5, 3).contiguous() # Now: (batch, num_widths, num_aspect_ratios, h, w, num_classes) return x