Skip to content

Instantly share code, notes, and snippets.

@rlan
Last active July 14, 2023 10:32
Show Gist options
  • Save rlan/07e9cdc4d395e17618bc17cf22e637a6 to your computer and use it in GitHub Desktop.
Save rlan/07e9cdc4d395e17618bc17cf22e637a6 to your computer and use it in GitHub Desktop.
MNIST training starter kit with Ray, PyTorch Lightning and PyTorch
"""
MNIST PyTorch Lightning Example
Ref: https://docs.ray.io/en/latest/tune/examples/includes/mnist_ptl_mini.html
"""
import math
import torch
from filelock import FileLock
from torch.nn import functional as F
from torchmetrics import Accuracy
import pytorch_lightning as pl
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
import os
from ray.tune.integration.pytorch_lightning import TuneReportCallback
from ray import air, tune
class LightningMNISTClassifier(pl.LightningModule):
    """Fully connected MNIST classifier with two hidden layers.

    Each input image is flattened to ``28 * 28`` features, passed through two
    ReLU-activated linear layers and a final 10-way linear layer, and turned
    into log-probabilities with ``log_softmax``.

    Args:
        config: Mapping providing ``"lr"`` (learning rate for Adam),
            ``"layer_1"`` and ``"layer_2"`` (hidden layer widths) and
            ``"batch_size"`` (stored on the instance as ``batch_size``).
        data_dir: Stored on the instance as ``data_dir``; falls back to the
            current working directory when falsy.
    """

    def __init__(self, config, data_dir=None):
        super().__init__()
        self.data_dir = data_dir or os.getcwd()
        self.lr = config["lr"]
        layer_1, layer_2 = config["layer_1"], config["layer_2"]
        self.batch_size = config["batch_size"]

        # MNIST images are (1, 28, 28), so each one flattens to 28 * 28 inputs.
        self.layer_1 = torch.nn.Linear(28 * 28, layer_1)
        self.layer_2 = torch.nn.Linear(layer_1, layer_2)
        self.layer_3 = torch.nn.Linear(layer_2, 10)
        self.accuracy = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened and must total ``28 * 28`` elements.

        Returns:
            Tensor of shape ``(batch, 10)`` holding ``log_softmax`` outputs.
        """
        # Flatten everything but the batch dimension; only the batch size is
        # needed, so the other dimensions are not unpacked.
        x = x.view(x.size(0), -1)
        x = torch.relu(self.layer_1(x))
        x = torch.relu(self.layer_2(x))
        return torch.log_softmax(self.layer_3(x), dim=1)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        """Compute and log the loss and accuracy for one training batch.

        Args:
            train_batch: Pair ``(x, y)`` of inputs and integer class targets.
            batch_idx: Index of the batch (unused).

        Returns:
            The negative log-likelihood loss for the batch.
        """
        x, y = train_batch
        # Call the module rather than ``forward`` so registered hooks run.
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        acc = self.accuracy(log_probs, y)
        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", acc)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute the loss and accuracy for one validation batch.

        Args:
            val_batch: Pair ``(x, y)`` of inputs and integer class targets.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with ``"val_loss"`` and ``"val_accuracy"`` for this batch;
            consumed by ``validation_epoch_end``.
        """
        x, y = val_batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        acc = self.accuracy(log_probs, y)
        return {"val_loss": loss, "val_accuracy": acc}

    def validation_epoch_end(self, outputs):
        """Log the epoch's validation loss and accuracy.

        Args:
            outputs: List of the dicts returned by ``validation_step``.

        NOTE(review): this hook is assumed to be available in the installed
        pytorch-lightning version; newer major releases are believed to
        replace it with ``on_validation_epoch_end`` -- confirm before
        upgrading.
        """
        # Plain mean over per-batch values (not weighted by batch size).
        avg_loss = torch.stack([out["val_loss"] for out in outputs]).mean()
        avg_acc = torch.stack([out["val_accuracy"] for out in outputs]).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)
def train_mnist_tune(config, num_epochs=10, num_gpus=0):
    """Train one MNIST classifier and report validation metrics to Ray Tune.

    Args:
        config: Hyperparameter mapping; handed to ``LightningMNISTClassifier``
            and read here for ``"batch_size"``.
        num_epochs: Passed to the trainer as ``max_epochs``.
        num_gpus: GPU count for the trainer; rounded up with ``math.ceil`` so
            a fractional value becomes a whole number.
    """
    dataset_root = os.path.abspath("./data")
    classifier = LightningMNISTClassifier(config, dataset_root)

    # Hold a file lock while the data module is constructed.
    lock_path = os.path.expanduser("~/.data.lock")
    with FileLock(lock_path):
        datamodule = MNISTDataModule(
            data_dir=dataset_root,
            num_workers=1,
            batch_size=config["batch_size"],
        )

    # Names reported to Tune, keyed to the metric names the model logs.
    reported_metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
    report_callback = TuneReportCallback(reported_metrics, on="validation_end")

    trainer_options = {
        "max_epochs": num_epochs,
        # If fractional GPUs passed in, convert to int.
        "gpus": math.ceil(num_gpus),
        "enable_progress_bar": False,
        "callbacks": [report_callback],
    }
    pl.Trainer(**trainer_options).fit(classifier, datamodule)
def tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0):
    """Run a Ray Tune hyperparameter search for the MNIST classifier.

    Prints the configuration of the best result (lowest ``"loss"``).

    Args:
        num_samples: Passed to ``tune.TuneConfig`` as ``num_samples``.
        num_epochs: Forwarded to ``train_mnist_tune`` as ``num_epochs``.
        gpus_per_trial: Forwarded to ``train_mnist_tune`` as ``num_gpus`` and
            requested as the per-trial ``"gpu"`` resource.
    """
    search_space = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
    }

    # Bind the fixed arguments first, then attach the per-trial resources.
    bound_trainable = tune.with_parameters(
        train_mnist_tune,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial,
    )
    resourced_trainable = tune.with_resources(
        bound_trainable,
        resources={"cpu": 1, "gpu": gpus_per_trial},
    )

    search_settings = tune.TuneConfig(
        metric="loss",
        mode="min",
        num_samples=num_samples,
    )
    run_settings = air.RunConfig(name="tune_mnist")

    tuner = tune.Tuner(
        resourced_trainable,
        tune_config=search_settings,
        run_config=run_settings,
        param_space=search_space,
    )
    result_grid = tuner.fit()

    best_result = result_grid.get_best_result()
    print("Best hyperparameters found were: ", best_result.config)
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--smoke-test",
        action="store_true",
        help="Finish quickly for testing",
    )
    # parse_known_args also returns the unrecognized arguments, which are
    # discarded here.
    options, _unrecognized = cli.parse_known_args()

    # Smoke test: a single one-epoch trial. Otherwise: ten trials of ten epochs.
    budget = 1 if options.smoke_test else 10
    tune_mnist(num_samples=budget, num_epochs=budget, gpus_per_trial=0)
absl-py==1.4.0
aiofiles==22.1.0
aiohttp==3.8.4
aiohttp-cors==0.7.0
aiorwlock==1.3.0
aiosignal==1.3.1
aiosqlite==0.19.0
annotated-types==0.5.0
anyio==3.6.2
appnope==0.1.3
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
asttokens==2.2.1
async-lru==2.0.3
async-timeout==4.0.2
attrs==23.1.0
Babel==2.12.1
backcall==0.2.0
backoff==2.2.1
beautifulsoup4==4.12.2
bleach==6.0.0
blessed==1.20.0
cachetools==5.3.0
certifi==2022.12.7
cffi==1.15.1
charset-normalizer==3.1.0
click==8.1.4
colorful==0.5.5
comm==0.1.3
contourpy==1.0.7
croniter==1.3.15
cycler==0.11.0
dateutils==0.6.12
debugpy==1.6.7
decorator==5.1.1
deepdiff==6.3.1
defusedxml==0.7.1
distlib==0.3.6
executing==1.2.0
fastapi==0.88.0
fastjsonschema==2.16.3
filelock==3.12.0
fonttools==4.39.3
fqdn==1.5.1
frozenlist==1.3.3
fsspec==2023.4.0
google-api-core==2.11.1
google-auth==2.17.3
google-auth-oauthlib==1.0.0
googleapis-common-protos==1.59.1
gpustat==1.1
grpcio==1.49.1
h11==0.14.0
idna==3.4
imageio==2.28.1
importlib-metadata==6.6.0
importlib-resources==5.12.0
inquirer==3.1.3
ipykernel==6.22.0
ipython==8.12.1
ipython-genutils==0.2.0
ipywidgets==8.0.7
isoduration==20.11.0
itsdangerous==2.1.2
jedi==0.18.2
Jinja2==3.1.2
joblib==1.2.0
json5==0.9.11
jsonpointer==2.3
jsonschema==4.17.3
jupyter-events==0.6.3
jupyter-lsp==2.2.0
jupyter-ydoc==0.2.4
jupyter_client==8.2.0
jupyter_core==5.3.0
jupyter_server==2.5.0
jupyter_server_fileid==0.9.0
jupyter_server_terminals==0.4.4
jupyter_server_ydoc==0.8.0
jupyterlab==4.0.2
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.8
jupyterlab_server==2.22.1
kiwisolver==1.4.4
lazy_loader==0.2
lightning==1.9.5
lightning-bolts==0.7.0
lightning-cloud==0.5.37
lightning-utilities==0.8.0
Markdown==3.4.3
markdown-it-py==3.0.0
MarkupSafe==2.1.2
matplotlib==3.7.1
matplotlib-inline==0.1.6
mdurl==0.1.2
mistune==2.0.5
mpmath==1.3.0
msgpack==1.0.5
multidict==6.0.4
nbclassic==0.5.6
nbclient==0.7.4
nbconvert==7.3.1
nbformat==5.8.0
nest-asyncio==1.5.6
networkx==3.1
nltk==3.8.1
notebook==6.5.4
notebook_shim==0.2.3
numpy==1.24.3
nvidia-ml-py==12.535.77
oauthlib==3.2.2
opencensus==0.11.2
opencensus-context==0.1.3
opencv-python-headless==4.8.0.74
ordered-set==4.1.0
packaging==23.1
pandas==2.0.3
pandocfilters==1.5.0
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.5.0
pkgutil_resolve_name==1.3.10
platformdirs==3.5.0
prometheus-client==0.16.0
prompt-toolkit==3.0.38
protobuf==4.22.3
psutil==5.9.5
ptyprocess==0.7.0
pure-eval==0.2.2
py-spy==0.3.14
pyarrow==12.0.1
pyasn1==0.5.0
pyasn1-modules==0.3.0
pycparser==2.21
pydantic==1.10.11
Pygments==2.15.1
PyJWT==2.7.0
pyparsing==3.0.9
pyrsistent==0.19.3
python-dateutil==2.8.2
python-editor==1.0.4
python-json-logger==2.0.7
python-multipart==0.0.6
pytorch-lightning==1.9.5
pytz==2023.3
PyWavelets==1.4.1
PyYAML==6.0
pyzmq==25.0.2
ray==2.5.1
readchar==4.0.5
regex==2023.6.3
requests==2.29.0
requests-oauthlib==1.3.1
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.4.2
rsa==4.9
scikit-image==0.21.0
scikit-learn==1.3.0
scipy==1.9.1
seaborn==0.12.2
Send2Trash==1.8.2
six==1.16.0
smart-open==6.3.0
sniffio==1.3.0
soupsieve==2.4.1
stack-data==0.6.2
starlette==0.22.0
starsessions==1.3.0
sympy==1.11.1
tensorboard==2.13.0
tensorboard-data-server==0.7.0
tensorboard-plugin-wit==1.8.1
tensorboardX==2.6.1
terminado==0.17.1
threadpoolctl==3.1.0
tifffile==2023.4.12
tinycss2==1.2.1
tomli==2.0.1
torch==2.0.1
torchaudio==2.0.2
torchmetrics==1.0.0
torchvision==0.15.2
tornado==6.3.1
tqdm==4.65.0
traitlets==5.9.0
typing_extensions==4.7.1
tzdata==2023.3
uri-template==1.2.0
urllib3==1.26.15
uvicorn==0.22.0
virtualenv==20.21.0
wcwidth==0.2.6
webcolors==1.13
webencodings==0.5.1
websocket-client==1.5.1
websockets==11.0.3
Werkzeug==2.3.2
widgetsnbextension==4.0.8
y-py==0.5.9
yarl==1.9.2
ypy-websocket==0.8.2
zipp==3.15.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment