This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytorch_lightning as pl | |
from pl_bolts.models.self_supervised import SimCLR | |
from pl_bolts.datamodules import ImagenetDataModule | |
from pl_bolts.models.self_supervised.simclr.transforms import SimCLRTrainDataTransform, SimCLREvalDataTransform | |
# data | |
datamodule = ImagenetDataModule(image_size=196) | |
# transforms | |
(c, h, w) = datamodule.size() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.utils.data as tud | |
import torch | |
from typing import List | |
import random | |
import nlp | |
def prepare_dataset(tokenizer, split="train", max_length=120, num_datapoints=100_000): | |
"""Prepares WikiText-103 dataset""" | |
wikitext = nlp.load_dataset("wikitext", "wikitext-103-v1") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import pytorch_lightning as pl | |
from pl_bolts.models.regression import LogisticRegression | |
from pl_bolts.datamodules import ImagenetDataModule | |
# use imagenet | |
imagenet = ImagenetDataModule(data_dir=os.environ['IMGNET_PATH'], meta_root=os.environ['META_ROOT'], image_size=224, num_workers=32) | |
# input size is channels x height x width | |
input_dim = 3 * 224 * 224 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytorch_lightning.metrics.functional as plm | |
pred = torch.tensor([0, 1, 2, 3]) | |
target = torch.tensor([0, 1, 2, 2]) | |
# many popular classification metrics and more | |
plm.accuracy(pred, target) | |
plm.auc(pred, target) | |
plm.auroc(pred, target) | |
plm.average_precision(pred, target) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash | |
# | |
# script to extract ImageNet dataset | |
# ILSVRC2012_img_train.tar (about 138 GB) | |
# ILSVRC2012_img_val.tar (about 6.3 GB) | |
# make sure ILSVRC2012_img_train.tar & ILSVRC2012_img_val.tar are in your current directory | |
# | |
# https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md | |
# | |
# train/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import NeptuneLogger, TensorBoardLogger

# Log one training run to two backends at once by handing the Trainer a
# list of loggers.
# NOTE(review): NeptuneLogger is constructed without arguments here, so its
# credentials/project presumably come from the environment -- confirm for
# the installed neptune / Lightning versions.
neptune = NeptuneLogger()
# TensorBoardLogger has a required ``save_dir`` argument (the directory its
# event files are written under); the original ``Tensorboard()`` name does
# not exist in ``pytorch_lightning.loggers``.
tensorboard = TensorBoardLogger(save_dir="lightning_logs")

model = ...  # placeholder: substitute a LightningModule instance
trainer = Trainer(logger=[neptune, tensorboard])
trainer.fit(model)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
train = DataLoader(...) | |
test = DataLoader(...) | |
val = DataLoader(...) | |
trainer = Trainer() | |
model = LightningModule() | |
trainer.fit(model, | |
train_dataloader=train, | |
val_dataloader=val, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytorch_lightning as pl | |
class MyAPICallback(pl.Callback): | |
def on_init_start(self, trainer): | |
requests.post('model started') | |
def on_init_end(self, trainer): | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytorch_lightning as pl | |
class MyLoggingCallback(pl.Callback):
    """Callback that calls into the logger's experiment object around Trainer init.

    The experiment methods invoked here (``log_tensorboard_images``,
    ``save_or_something``) and their ``...`` arguments are placeholders in
    this example.
    """

    def on_init_start(self, trainer):
        """Run when Trainer initialisation begins."""
        # NOTE(review): this reads ``trainer.logger`` at the very start of
        # Trainer initialisation; confirm the logger is already attached at
        # this point in the Lightning version in use, otherwise move this
        # call to a later hook.
        experiment = trainer.logger.experiment
        experiment.log_tensorboard_images(...)

    def on_init_end(self, trainer):
        """Run once Trainer initialisation has finished."""
        experiment = trainer.logger.experiment
        experiment.save_or_something(...)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytorch_lightning as pl | |
class MyPrintingCallback(pl.Callback):
    """Callback that prints a message before and after Trainer initialisation."""

    def on_init_start(self, trainer):
        """Announce that Trainer initialisation is starting."""
        message = "Starting to init trainer!"
        print(message)

    def on_init_end(self, trainer):
        """Announce that Trainer initialisation is complete."""
        message = "trainer is init now"
        print(message)
NewerOlder