Skip to content

Instantly share code, notes, and snippets.

@williamFalcon
Last active January 23, 2021 08:10
Show Gist options
  • Save williamFalcon/cacea583bd78bf7896185823f47fcaf2 to your computer and use it in GitHub Desktop.
Save williamFalcon/cacea583bd78bf7896185823f47fcaf2 to your computer and use it in GitHub Desktop.
import pytorch_lightning as pl
from torch import nn
class TransferLearningTemplate(pl.LightningModule):
    """Transfer-learning template: a pretrained feature extractor plus a linear head.

    ``SomePretrainedModel``, ``dim``, ``num_classes`` and ``some_loss`` are
    placeholders that are not defined in this file; substitute your own
    backbone, feature size, class count and loss function.
    """

    def __init__(self):
        # Must run before any submodule is assigned: torch.nn.Module rejects
        # module attributes that are set before its own __init__ was called.
        super().__init__()

        # basically a feature extractor
        self.pretrained_model = SomePretrainedModel(load_weights=True)

        # a model that uses features to do something you care about
        self.finetune_model = nn.Linear(dim, num_classes)

    def forward(self, x):
        """Return the head's output for the features extracted from ``x``."""
        features = self.pretrained_model(x)
        return self.finetune_model(features)

    def training_step(self, batch, batch_num):
        """Compute the loss for one batch.

        Args:
            batch: A two-element iterable unpacked as ``(x, y)``; ``x`` is fed
                to the model and ``y`` is passed to the loss.
            batch_num: Supplied by the caller (presumably the batch index —
                confirm against the trainer); unused here.

        Returns:
            The value of ``some_loss`` applied to the model output and ``y``.
        """
        x, y = batch
        # Call the module itself rather than ``forward`` so that any
        # registered nn.Module hooks are run as well.
        out = self(x)
        return some_loss(out, y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment