update template code
This commit is contained in:
39
mmp/a5/main.py
Normal file
39
mmp/a5/main.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from .model import MmpNet
|
||||
|
||||
|
||||
def step(
|
||||
model: MmpNet,
|
||||
criterion,
|
||||
optimizer: optim.Optimizer,
|
||||
img_batch: torch.Tensor,
|
||||
lbl_batch: torch.Tensor,
|
||||
) -> float:
|
||||
"""Performs one update step for the model
|
||||
|
||||
@return: The loss for the specified batch. Return a float and not a PyTorch tensor
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def get_random_sampling_mask(labels: torch.Tensor, neg_ratio: float) -> torch.Tensor:
|
||||
"""
|
||||
@param labels: The label tensor that is returned by your data loader.
|
||||
The values are either 0 (negative label) or 1 (positive label).
|
||||
@param neg_ratio: The desired negative/positive ratio.
|
||||
Hint: after computing the mask, check if the neg_ratio is fulfilled.
|
||||
@return: A tensor with the same shape as labels
|
||||
"""
|
||||
assert labels.min() >= 0 and labels.max() <= 1 # remove this line if you want
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def main():
|
||||
"""Put your training code for exercises 5.2 and 5.3 here"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
9
mmp/a5/model.py
Normal file
9
mmp/a5/model.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import torch
|
||||
|
||||
|
||||
class MmpNet(torch.nn.Module):
|
||||
def __init__(self, num_widths: int, num_aspect_ratios: int):
|
||||
raise NotImplementedError()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
Reference in New Issue
Block a user