Skip to content

Instantly share code, notes, and snippets.

View williamFalcon's full-sized avatar
🎯
Focusing

William Falcon williamFalcon

🎯
Focusing
View GitHub Profile
import pytorch_lightning as pl


class TransferLearningTemplate(pl.LightningModule):
    """Transfer-learning template: a pretrained feature extractor feeding a trainable head."""

    def __init__(self):
        # BUG FIX: nn.Module/LightningModule bookkeeping must run before any
        # submodule is assigned, otherwise parameter registration breaks.
        super().__init__()
        # basically a feature extractor
        # NOTE(review): `SomePretrainedModel`, `nn`, `dim`, `num_classes` are
        # placeholders defined elsewhere — this is a template, not runnable as-is.
        self.pretrained_model = SomePretrainedModel(load_weights=True)
        # a model that uses features to do something you care about
        self.finetune_model = nn.Linear(dim, num_classes)
#!/bin/bash -l
# SLURM SUBMIT SCRIPT
# Requests 32 nodes with 8 GPUs each (256 GPUs total).
#SBATCH --nodes=32
#SBATCH --gres=gpu:8
# One task (process) per GPU on each node — required for per-GPU DDP workers.
#SBATCH --ntasks-per-node=8
# --mem=0 asks SLURM for all available memory on each node.
#SBATCH --mem=0
# Wall-clock limit: 2 hours.
#SBATCH --time=0-02:00:00
# activate conda env
import pytorch_lightning as pl


def main(hparams):
    """Train FastStyleTransfer with DDP across 32 nodes x 8 GPUs.

    :param hparams: parsed hyperparameters forwarded to the model.
    """
    # init model
    model = FastStyleTransfer(hparams)
    # one DDP process per GPU: 8 GPUs per node, 32 nodes
    trainer = pl.Trainer(gpus=8, distributed_backend='ddp', nb_gpu_nodes=32)
    # BUG FIX: fit() must receive the model to train; the original called
    # trainer.fit() with no arguments, which trains nothing.
    trainer.fit(model)
import pytorch_lightning as pl


def main(hparams):
    """Train FastStyleTransfer on 4 GPUs of a single machine.

    :param hparams: parsed hyperparameters forwarded to the model.
    """
    # init model
    model = FastStyleTransfer(hparams)
    trainer = pl.Trainer(gpus=4)
    # BUG FIX: fit() must receive the model to train; the original called
    # trainer.fit() with no arguments, which trains nothing.
    trainer.fit(model)
import pytorch_lightning as pl


def main(hparams):
    """Train FastStyleTransfer on CPU (default Trainer).

    :param hparams: parsed hyperparameters forwarded to the model.
    """
    # init model
    model = FastStyleTransfer(hparams)
    trainer = pl.Trainer()
    # BUG FIX: fit() must receive the model to train; the original called
    # trainer.fit() with no arguments, which trains nothing.
    trainer.fit(model)
import torch
import pytorch_lightning as pl


class FastStyleTransfer(pl.LightningModule):
    """Fast neural style transfer: a TransformerNet trained against a frozen VGG16."""

    def __init__(self, hparams):
        # BUG FIX: nn.Module/LightningModule bookkeeping must run before any
        # submodule is assigned, otherwise parameter registration breaks.
        super().__init__()
        self.hparams = hparams
        # the image-transformation network being trained
        self.transformer = TransformerNet()
        # frozen VGG16 used only to compute perceptual (content/style) losses
        self.vgg = Vgg16(requires_grad=False)
# Training CLI flags. NOTE(review): `train_arg_parser` is created elsewhere
# (presumably an argparse sub-parser) — confirm against the full script.
train_arg_parser.add_argument("--epochs", type=int, default=2,
                              help="number of training epochs, default is 2")
train_arg_parser.add_argument("--batch-size", type=int, default=4,
                              help="batch size for training, default is 4")
# --dataset points at a parent folder containing one sub-folder of images
# (the layout torchvision's ImageFolder expects).
train_arg_parser.add_argument("--dataset", type=str, required=True,
                              help="path to training dataset, the path should point to a folder "
                                   "containing another folder with all the training images")
train_arg_parser.add_argument("--style-image", type=str, default="images/style-images/mosaic.jpg",
                              help="path to style-image")
@pl.data_loader
def train_dataloader(self):
    """Return the DataLoader over the training set (required by Lightning)."""
    # REQUIRED
    return DataLoader(...)
@pl.data_loader
def val_dataloader(self):
    """Return the DataLoader over the validation set (optional in Lightning)."""
    # OPTIONAL
    return DataLoader(...)
@pl.data_loader
def train_dataloader(self):
    """Build and return the training DataLoader.

    Images are resized/center-cropped to ``hparams.image_size`` and scaled
    to the 0-255 range the style-transfer VGG losses expect.
    """
    transform = transforms.Compose([
        transforms.Resize(self.hparams.image_size),
        transforms.CenterCrop(self.hparams.image_size),
        transforms.ToTensor(),
        # VGG-based perceptual losses are computed on 0-255 pixel values
        transforms.Lambda(lambda x: x.mul(255))
    ])
    # BUG FIX: read the dataset path from self.hparams (set in __init__)
    # instead of the module-level `args`, consistent with the rest of the class.
    train_dataset = datasets.ImageFolder(self.hparams.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=self.hparams.batch_size)
    # BUG FIX: the original built the loader but never returned it,
    # so Lightning received None as the train dataloader.
    return train_loader
def configure_optimizers(self):
    """Create the single Adam optimizer that updates the transformer network.

    Learning rate is taken from ``self.hparams.lr``; the frozen VGG is
    deliberately excluded from optimization.
    """
    optimizer = torch.optim.Adam(
        self.transformer.parameters(),
        self.hparams.lr,
    )
    return optimizer