Skip to content

Instantly share code, notes, and snippets.

@williamFalcon
Last active June 25, 2020 23:02
Show Gist options
  • Save williamFalcon/bc87099b0d3a0c40ef3af7db969c04d6 to your computer and use it in GitHub Desktop.
Save williamFalcon/bc87099b0d3a0c40ef3af7db969c04d6 to your computer and use it in GitHub Desktop.
import os
import pytorch_lightning as pl
from pl_bolts.models.regression import LogisticRegression
from pl_bolts.datamodules import ImagenetDataModule
# Single source of truth for the image resolution: both the datamodule's
# image_size and the model's input_dim are derived from it so they cannot drift.
IMAGE_SIZE = 224
# Channel count used for the flattened input size (3, per the original script).
NUM_CHANNELS = 3
# Shared by the train and val dataloaders.
BATCH_SIZE = 256
NUM_WORKERS = 32
NUM_GPUS = 2


def require_env(name):
    """Return the value of the environment variable *name*.

    Raises:
        KeyError: if *name* is not set; the message names the missing variable
            so the user knows what to export.
    """
    try:
        return os.environ[name]
    except KeyError:
        # Re-raise with a readable message instead of a bare KeyError('NAME').
        raise KeyError(
            f"environment variable {name!r} must be set to the ImageNet data/meta path"
        ) from None


def compute_input_dim(image_size, channels=NUM_CHANNELS):
    """Return the model input size: channels x height x width for a square image."""
    return channels * image_size * image_size


def main():
    """Fit a logistic-regression baseline on ImageNet.

    Reads the dataset location from the IMGNET_PATH and META_ROOT environment
    variables, then trains with NUM_GPUS GPUs, the 'ddp' backend and 16-bit
    precision.
    """
    # use imagenet
    imagenet = ImagenetDataModule(
        data_dir=require_env('IMGNET_PATH'),
        meta_root=require_env('META_ROOT'),
        image_size=IMAGE_SIZE,
        num_workers=NUM_WORKERS,
    )

    # logistic regression; input size is channels x height x width
    model = LogisticRegression(
        input_dim=compute_input_dim(IMAGE_SIZE),
        num_classes=imagenet.num_classes,
        # NOTE(review): 1e-7 is kept from the original; confirm it is intentional.
        learning_rate=1e-7,
    )

    trainer = pl.Trainer(gpus=NUM_GPUS, distributed_backend='ddp', precision=16)
    trainer.fit(
        model,
        imagenet.train_dataloader(batch_size=BATCH_SIZE),
        imagenet.val_dataloader(batch_size=BATCH_SIZE),
    )


# Guard the entry point: processes that re-import this module (e.g. DataLoader
# workers under the "spawn" start method) must not start training again.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment