Last active
July 14, 2023 10:32
-
-
Save rlan/07e9cdc4d395e17618bc17cf22e637a6 to your computer and use it in GitHub Desktop.
MNIST training starter kit with Ray, PyTorch Lightning and PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
MNIST PyTorch Lightning Example | |
Ref: https://docs.ray.io/en/latest/tune/examples/includes/mnist_ptl_mini.html | |
""" | |
import math | |
import torch | |
from filelock import FileLock | |
from torch.nn import functional as F | |
from torchmetrics import Accuracy | |
import pytorch_lightning as pl | |
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule | |
import os | |
from ray.tune.integration.pytorch_lightning import TuneReportCallback | |
from ray import air, tune | |
class LightningMNISTClassifier(pl.LightningModule):
    """Fully connected MNIST classifier whose sizes come from a config dict.

    The network is three ``torch.nn.Linear`` layers with ReLU between them
    and a log-softmax output over 10 classes.

    Args:
        config: Mapping that must provide ``"lr"``, ``"layer_1"``,
            ``"layer_2"`` and ``"batch_size"``.
        data_dir: Stored on the module as ``self.data_dir``; falls back to
            the current working directory when falsy.
    """

    def __init__(self, config, data_dir=None):
        super().__init__()
        self.data_dir = data_dir or os.getcwd()
        self.lr = config["lr"]
        layer_1, layer_2 = config["layer_1"], config["layer_2"]
        self.batch_size = config["batch_size"]

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, layer_1)
        self.layer_2 = torch.nn.Linear(layer_1, layer_2)
        self.layer_3 = torch.nn.Linear(layer_2, 10)
        self.accuracy = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        """Return per-class log-probabilities, one row per input sample."""
        # Flatten everything after the batch dimension; the first Linear
        # layer still rejects inputs that do not flatten to 28 * 28 features.
        x = x.view(x.size(0), -1)
        x = torch.relu(self.layer_1(x))
        x = torch.relu(self.layer_2(x))
        return torch.log_softmax(self.layer_3(x), dim=1)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        """Compute and log the NLL loss and accuracy for one training batch.

        Returns:
            The loss tensor for this batch.
        """
        x, y = train_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", acc)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute the NLL loss and accuracy for one validation batch.

        Returns:
            Dict with ``"val_loss"``, ``"val_accuracy"`` and the number of
            samples in the batch under ``"batch_size"``.
        """
        x, y = val_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        # The sample count lets the epoch-end hook weight this batch correctly.
        return {"val_loss": loss, "val_accuracy": acc, "batch_size": x.size(0)}

    def validation_epoch_end(self, outputs):
        """Log sample-weighted mean validation loss and accuracy.

        Both per-batch values are already means over the batch, so a plain
        mean of them over-weights a smaller final batch. Weighting by batch
        size yields the mean over all validation samples.
        """
        losses = torch.stack([out["val_loss"] for out in outputs])
        accs = torch.stack([out["val_accuracy"] for out in outputs])
        weights = torch.tensor(
            [out["batch_size"] for out in outputs],
            dtype=losses.dtype,
            device=losses.device,
        )
        total = weights.sum()
        avg_loss = (losses * weights).sum() / total
        avg_acc = (accs * weights).sum() / total
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)
def train_mnist_tune(config, num_epochs=10, num_gpus=0):
    """Train one ``LightningMNISTClassifier`` and report validation metrics.

    Args:
        config: Hyperparameter mapping passed to the model; its
            ``"batch_size"`` entry is also used for the data module.
        num_epochs: Value passed to the trainer as ``max_epochs``.
        num_gpus: GPU count for this run; fractional values are rounded up.
            ``0`` selects the CPU.
    """
    data_dir = os.path.abspath("./data")
    model = LightningMNISTClassifier(config, data_dir)
    with FileLock(os.path.expanduser("~/.data.lock")):
        dm = MNISTDataModule(
            data_dir=data_dir, num_workers=1, batch_size=config["batch_size"]
        )
    metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}

    # If fractional GPUs passed in, convert to int.
    gpu_count = math.ceil(num_gpus)
    # `gpus=` is deprecated in favour of `accelerator=`/`devices=`. A non-zero
    # count (including negative values) is forwarded unchanged as `devices`.
    if gpu_count != 0:
        device_kwargs = {"accelerator": "gpu", "devices": gpu_count}
    else:
        device_kwargs = {"accelerator": "cpu", "devices": 1}

    trainer = pl.Trainer(
        max_epochs=num_epochs,
        enable_progress_bar=False,
        callbacks=[TuneReportCallback(metrics, on="validation_end")],
        **device_kwargs,
    )
    trainer.fit(model, dm)
def tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0):
    """Run a Ray Tune search over the MNIST classifier and print the best config.

    Args:
        num_samples: Passed to ``tune.TuneConfig`` as ``num_samples``.
        num_epochs: Forwarded to ``train_mnist_tune`` as ``num_epochs``.
        gpus_per_trial: Forwarded to ``train_mnist_tune`` as ``num_gpus`` and
            used as the ``"gpu"`` resource request for each trial.
    """
    search_space = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
    }

    # Bind the fixed arguments first, then attach the per-trial resources.
    bound_trainable = tune.with_parameters(
        train_mnist_tune, num_epochs=num_epochs, num_gpus=gpus_per_trial
    )
    per_trial_resources = {"cpu": 1, "gpu": gpus_per_trial}
    resourced_trainable = tune.with_resources(
        bound_trainable, resources=per_trial_resources
    )

    # The search minimises the metric reported under the name "loss".
    search_settings = tune.TuneConfig(
        metric="loss", mode="min", num_samples=num_samples
    )
    run_settings = air.RunConfig(name="tune_mnist")

    tuner = tune.Tuner(
        resourced_trainable,
        tune_config=search_settings,
        run_config=run_settings,
        param_space=search_space,
    )
    results = tuner.fit()

    best_config = results.get_best_result().config
    print("Best hyperparameters found were: ", best_config)
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    # Unrecognised command-line arguments are ignored rather than rejected.
    options, _unknown = cli.parse_known_args()

    # A smoke test runs one trial for one epoch; a full run uses ten of each.
    budget = 1 if options.smoke_test else 10
    tune_mnist(num_samples=budget, num_epochs=budget, gpus_per_trial=0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
absl-py==1.4.0 | |
aiofiles==22.1.0 | |
aiohttp==3.8.4 | |
aiohttp-cors==0.7.0 | |
aiorwlock==1.3.0 | |
aiosignal==1.3.1 | |
aiosqlite==0.19.0 | |
annotated-types==0.5.0 | |
anyio==3.6.2 | |
appnope==0.1.3 | |
argon2-cffi==21.3.0 | |
argon2-cffi-bindings==21.2.0 | |
arrow==1.2.3 | |
asttokens==2.2.1 | |
async-lru==2.0.3 | |
async-timeout==4.0.2 | |
attrs==23.1.0 | |
Babel==2.12.1 | |
backcall==0.2.0 | |
backoff==2.2.1 | |
beautifulsoup4==4.12.2 | |
bleach==6.0.0 | |
blessed==1.20.0 | |
cachetools==5.3.0 | |
certifi==2022.12.7 | |
cffi==1.15.1 | |
charset-normalizer==3.1.0 | |
click==8.1.4 | |
colorful==0.5.5 | |
comm==0.1.3 | |
contourpy==1.0.7 | |
croniter==1.3.15 | |
cycler==0.11.0 | |
dateutils==0.6.12 | |
debugpy==1.6.7 | |
decorator==5.1.1 | |
deepdiff==6.3.1 | |
defusedxml==0.7.1 | |
distlib==0.3.6 | |
executing==1.2.0 | |
fastapi==0.88.0 | |
fastjsonschema==2.16.3 | |
filelock==3.12.0 | |
fonttools==4.39.3 | |
fqdn==1.5.1 | |
frozenlist==1.3.3 | |
fsspec==2023.4.0 | |
google-api-core==2.11.1 | |
google-auth==2.17.3 | |
google-auth-oauthlib==1.0.0 | |
googleapis-common-protos==1.59.1 | |
gpustat==1.1 | |
grpcio==1.49.1 | |
h11==0.14.0 | |
idna==3.4 | |
imageio==2.28.1 | |
importlib-metadata==6.6.0 | |
importlib-resources==5.12.0 | |
inquirer==3.1.3 | |
ipykernel==6.22.0 | |
ipython==8.12.1 | |
ipython-genutils==0.2.0 | |
ipywidgets==8.0.7 | |
isoduration==20.11.0 | |
itsdangerous==2.1.2 | |
jedi==0.18.2 | |
Jinja2==3.1.2 | |
joblib==1.2.0 | |
json5==0.9.11 | |
jsonpointer==2.3 | |
jsonschema==4.17.3 | |
jupyter-events==0.6.3 | |
jupyter-lsp==2.2.0 | |
jupyter-ydoc==0.2.4 | |
jupyter_client==8.2.0 | |
jupyter_core==5.3.0 | |
jupyter_server==2.5.0 | |
jupyter_server_fileid==0.9.0 | |
jupyter_server_terminals==0.4.4 | |
jupyter_server_ydoc==0.8.0 | |
jupyterlab==4.0.2 | |
jupyterlab-pygments==0.2.2 | |
jupyterlab-widgets==3.0.8 | |
jupyterlab_server==2.22.1 | |
kiwisolver==1.4.4 | |
lazy_loader==0.2 | |
lightning==1.9.5 | |
lightning-bolts==0.7.0 | |
lightning-cloud==0.5.37 | |
lightning-utilities==0.8.0 | |
Markdown==3.4.3 | |
markdown-it-py==3.0.0 | |
MarkupSafe==2.1.2 | |
matplotlib==3.7.1 | |
matplotlib-inline==0.1.6 | |
mdurl==0.1.2 | |
mistune==2.0.5 | |
mpmath==1.3.0 | |
msgpack==1.0.5 | |
multidict==6.0.4 | |
nbclassic==0.5.6 | |
nbclient==0.7.4 | |
nbconvert==7.3.1 | |
nbformat==5.8.0 | |
nest-asyncio==1.5.6 | |
networkx==3.1 | |
nltk==3.8.1 | |
notebook==6.5.4 | |
notebook_shim==0.2.3 | |
numpy==1.24.3 | |
nvidia-ml-py==12.535.77 | |
oauthlib==3.2.2 | |
opencensus==0.11.2 | |
opencensus-context==0.1.3 | |
opencv-python-headless==4.8.0.74 | |
ordered-set==4.1.0 | |
packaging==23.1 | |
pandas==2.0.3 | |
pandocfilters==1.5.0 | |
parso==0.8.3 | |
pexpect==4.8.0 | |
pickleshare==0.7.5 | |
Pillow==9.5.0 | |
pkgutil_resolve_name==1.3.10 | |
platformdirs==3.5.0 | |
prometheus-client==0.16.0 | |
prompt-toolkit==3.0.38 | |
protobuf==4.22.3 | |
psutil==5.9.5 | |
ptyprocess==0.7.0 | |
pure-eval==0.2.2 | |
py-spy==0.3.14 | |
pyarrow==12.0.1 | |
pyasn1==0.5.0 | |
pyasn1-modules==0.3.0 | |
pycparser==2.21 | |
pydantic==1.10.11 | |
Pygments==2.15.1 | |
PyJWT==2.7.0 | |
pyparsing==3.0.9 | |
pyrsistent==0.19.3 | |
python-dateutil==2.8.2 | |
python-editor==1.0.4 | |
python-json-logger==2.0.7 | |
python-multipart==0.0.6 | |
pytorch-lightning==1.9.5 | |
pytz==2023.3 | |
PyWavelets==1.4.1 | |
PyYAML==6.0 | |
pyzmq==25.0.2 | |
ray==2.5.1 | |
readchar==4.0.5 | |
regex==2023.6.3 | |
requests==2.29.0 | |
requests-oauthlib==1.3.1 | |
rfc3339-validator==0.1.4 | |
rfc3986-validator==0.1.1 | |
rich==13.4.2 | |
rsa==4.9 | |
scikit-image==0.21.0 | |
scikit-learn==1.3.0 | |
scipy==1.9.1 | |
seaborn==0.12.2 | |
Send2Trash==1.8.2 | |
six==1.16.0 | |
smart-open==6.3.0 | |
sniffio==1.3.0 | |
soupsieve==2.4.1 | |
stack-data==0.6.2 | |
starlette==0.22.0 | |
starsessions==1.3.0 | |
sympy==1.11.1 | |
tensorboard==2.13.0 | |
tensorboard-data-server==0.7.0 | |
tensorboard-plugin-wit==1.8.1 | |
tensorboardX==2.6.1 | |
terminado==0.17.1 | |
threadpoolctl==3.1.0 | |
tifffile==2023.4.12 | |
tinycss2==1.2.1 | |
tomli==2.0.1 | |
torch==2.0.1 | |
torchaudio==2.0.2 | |
torchmetrics==1.0.0 | |
torchvision==0.15.2 | |
tornado==6.3.1 | |
tqdm==4.65.0 | |
traitlets==5.9.0 | |
typing_extensions==4.7.1 | |
tzdata==2023.3 | |
uri-template==1.2.0 | |
urllib3==1.26.15 | |
uvicorn==0.22.0 | |
virtualenv==20.21.0 | |
wcwidth==0.2.6 | |
webcolors==1.13 | |
webencodings==0.5.1 | |
websocket-client==1.5.1 | |
websockets==11.0.3 | |
Werkzeug==2.3.2 | |
widgetsnbextension==4.0.8 | |
y-py==0.5.9 | |
yarl==1.9.2 | |
ypy-websocket==0.8.2 | |
zipp==3.15.0 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment