Created August 14, 2021 08:10
-
-
Save FeryET/f428b6f2233c91397e87e3054a78e234 to your computer and use it in GitHub Desktop.
PneumoniaNet
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def load_pretrained():
    """Return the ``features`` submodule of torchvision's pretrained
    MobileNetV3-Small.

    Only the convolutional feature extractor is returned; the network's own
    classifier head is discarded so the caller can attach its own.

    NOTE: newer torchvision releases deprecate ``pretrained=`` in favour of
    ``weights=``; this keeps the original call for compatibility.
    """
    backbone = torchvision.models.mobilenetv3.mobilenet_v3_small(
        pretrained=True,
        progress=True,  # show a progress bar if the weights must be downloaded
    )
    feature_extractor = backbone.features
    return feature_extractor
class PneumoniaNet(nn.Module):
    """Two-output image classifier on top of a pretrained encoder.

    The encoder comes from ``load_pretrained()``. All of its blocks except the
    last two are frozen. The head is global average pooling, flatten, dropout
    and a linear layer producing two unnormalized scores per sample.

    Args:
        input_dim: Expected size of the last two dimensions of every input
            (presumably ``(height, width)`` for NCHW images -- confirm with
            callers). ``forward`` rejects inputs that do not match.
        finetune: Accepted for interface compatibility; currently unused.
    """

    def __init__(self,
                 input_dim,
                 finetune=False):
        super().__init__()
        # Class-weighted cross-entropy. Assumes the iteration order of
        # ``classification_weights`` matches the class indices -- TODO confirm.
        self.loss_fn = nn.CrossEntropyLoss(
            weight=torch.tensor(list(classification_weights.values()),
                                dtype=torch.float32)
        )
        # Stored as a tensor (public attribute; kept for compatibility).
        self.input_dim = torch.as_tensor(input_dim)
        self.input_encoder = load_pretrained()
        # Freeze all encoder blocks but the last two; only those last two are
        # handed to the optimizer in ``generate_opt``.
        for param in self.input_encoder[:-2].parameters():
            param.requires_grad = False
        # NOTE(review): ``finetune`` is never read -- the freezing above is
        # applied unconditionally. Confirm the intended behaviour.
        self.output_decoder = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),  # global average pool -> (N, C, 1, 1)
            nn.Flatten(),                  # -> (N, C)
            nn.Dropout(),                  # default drop probability
            # 576 presumably equals the encoder's output channel count;
            # update it if the backbone changes.
            nn.Linear(576, 2),
        )
        self._init_weights()

    def _init_weights(self):
        """Apply the module-level ``weight_init`` to every head submodule."""
        self.output_decoder.apply(weight_init)

    def forward(self, x):
        """Return two unnormalized class scores per sample of ``x``.

        Raises:
            ValueError: if the last two dimensions of ``x`` differ from
                ``input_dim``. This is an explicit check rather than an
                ``assert`` so it is not stripped under ``python -O``.
        """
        expected = (int(self.input_dim[0]), int(self.input_dim[1]))
        actual = tuple(x.shape[-2:])
        if actual != expected:
            raise ValueError(
                f"expected input with trailing dimensions {expected}, "
                f"got {actual}"
            )
        x = self.input_encoder(x)
        return self.output_decoder(x)

    def loss(self, outputs, targets):
        """Class-weighted cross-entropy between ``outputs`` (scores) and
        ``targets`` (class indices)."""
        return self.loss_fn(outputs, targets)

    def generate_opt(self):
        """Build an AdamW optimizer over the trainable parameters.

        Two parameter groups are used:

        * the classification head, at the default learning rate ``LR``;
        * the last two (unfrozen) encoder blocks, at ``PRETRAINED_LR``.

        ``WEIGHT_DECAY``, ``BETAS`` and ``ADAM_EPS`` are module-level constants
        shared by both groups.
        """
        param_groups = [
            {"params": self.output_decoder.parameters()},
            {"params": self.input_encoder[-2:].parameters(),
             "lr": PRETRAINED_LR},
        ]
        return torch.optim.AdamW(
            param_groups,
            lr=LR,
            weight_decay=WEIGHT_DECAY,
            betas=BETAS,
            eps=ADAM_EPS,
        )
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment