adds solutions
This commit is contained in:
@@ -1,9 +1,40 @@
|
||||
import torch
|
||||
from torchvision import models
|
||||
from torchvision.models import MobileNet_V2_Weights
|
||||
from torch import nn
|
||||
|
||||
|
||||
class MmpNet(torch.nn.Module):
|
||||
def __init__(self, num_widths: int, num_aspect_ratios: int):
|
||||
raise NotImplementedError()
|
||||
def __init__(self, num_widths: int, num_aspect_ratios: int, num_classes: int = 2):
|
||||
super().__init__()
|
||||
self.backbone = models.mobilenet_v2(
|
||||
weights=MobileNet_V2_Weights.DEFAULT
|
||||
).features
|
||||
self.num_widths = num_widths
|
||||
self.num_aspect_ratios = num_aspect_ratios
|
||||
self.num_classes = num_classes
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
with torch.no_grad():
|
||||
dummy = torch.zeros(1, 3, 224, 224)
|
||||
backbone_out = self.backbone(dummy)
|
||||
in_channels = backbone_out.shape[1]
|
||||
|
||||
self.head = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
kernel_size=3,
|
||||
out_channels=self.get_required_output_channels(),
|
||||
stride=1,
|
||||
padding=1,
|
||||
)
|
||||
|
||||
def get_required_output_channels(self):
|
||||
return self.num_widths * self.num_aspect_ratios * self.num_classes
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = self.backbone(x)
|
||||
x = self.head(x)
|
||||
b, out_c, h, w = x.shape
|
||||
x = x.view(b, self.num_widths, self.num_aspect_ratios, self.num_classes, h, w)
|
||||
x = x.permute(0, 1, 2, 4, 5, 3).contiguous()
|
||||
# Now: (batch, num_widths, num_aspect_ratios, h, w, num_classes)
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user