pytorch-accelerated_blog_mnist_quickstart
# this example is taken from
# https://github.com/Chris-hughes10/pytorch-accelerated/blob/main/examples/train_mnist.py
import os

from torch import nn, optim
from torch.utils.data import random_split
from torchvision import transforms
from torchvision.datasets import MNIST

from pytorch_accelerated import Trainer


class MNISTModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(in_features=784, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=10),
        )

    def forward(self, x):
        # flatten each 28x28 image into a 784-dimensional vector
        return self.main(x.view(x.shape[0], -1))


def main():
    # download MNIST and split its 60,000 training images into train/validation/test subsets
    dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    train_dataset, validation_dataset, test_dataset = random_split(
        dataset, [50000, 5000, 5000]
    )

    model = MNISTModel()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    loss_func = nn.CrossEntropyLoss()

    # the Trainer wraps the model, loss function and optimizer,
    # handling device placement and the training/evaluation loops
    trainer = Trainer(
        model,
        loss_func=loss_func,
        optimizer=optimizer,
    )

    trainer.train(
        train_dataset=train_dataset,
        eval_dataset=validation_dataset,
        num_epochs=2,
        per_device_batch_size=32,
    )

    trainer.evaluate(
        dataset=test_dataset,
        per_device_batch_size=64,
    )


if __name__ == "__main__":
    main()
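
The script can be run directly with python for single-process training. Since pytorch-accelerated is built on top of Hugging Face accelerate, the same script can also be launched across multiple GPUs or with mixed precision without code changes; a minimal sketch, assuming the file is saved as train_mnist.py and the accelerate CLI is installed:

# answer the prompts once to describe your hardware and precision settings
accelerate config

# launch training using that configuration
accelerate launch train_mnist.py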