Skip to content

Instantly share code, notes, and snippets.

@devforfu
Created November 8, 2020 19:44
Show Gist options
  • Save devforfu/f90ce58d1d49da0e40aa62e35c097dd5 to your computer and use it in GitHub Desktop.
Lightning experiment
class BasicCNNExperiment(BaseExperiment):
    """Experiment built around a small two-stage convolutional classifier.

    The classifier head's output width comes from ``self.hparams.n_classes``.
    """

    def create_model(self) -> nn.Module:
        """Assemble the network: two 3x3 conv + ReLU stages, global max-pool, linear head.

        Returns:
            An ``nn.Sequential`` that maps a 3-channel input batch to
            ``self.hparams.n_classes`` logits per sample.
        """
        # Feature extractor: 3 -> 16 -> 32 channels with unpadded 3x3 kernels,
        # so each spatial dimension shrinks by 4 in total (inputs need H, W >= 5).
        stages = [
            nn.Conv2d(3, 16, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, 3),
            nn.ReLU(inplace=True),
            # Reduce whatever spatial extent remains to 1x1, then drop it, so the
            # head always receives exactly 32 features regardless of input size.
            nn.AdaptiveMaxPool2d(1),
            nn.Flatten(1, -1),
        ]
        # Head is appended last so layers are created (and randomly initialised)
        # in the same order as a single inline Sequential literal.
        stages.append(nn.Linear(32, self.hparams.n_classes))
        return nn.Sequential(*stages)

    def forward(self, batch: Dict) -> Any:
        """Run the model on one batch and compute its cross-entropy loss.

        Args:
            batch: a two-item sequence ``(features, targets)``. ``features``
                must have 3 channels to match the first convolution.
                NOTE(review): annotated as ``Dict`` but unpacked like a 2-tuple;
                confirm the batch type the caller actually supplies.

        Returns:
            A dict with ``'loss'`` (cross-entropy of the logits against the
            targets) and ``'outputs'`` (the raw logits from ``self.model``).
        """
        inputs, labels = batch
        # NOTE(review): ``self.model`` is presumably populated from
        # ``create_model`` by ``BaseExperiment``; verify in the base class.
        logits = self.model(inputs)
        return {'loss': cross_entropy(logits, labels), 'outputs': logits}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment