Skip to content

Instantly share code, notes, and snippets.

@yovasx2
Last active June 15, 2019 21:03
Show Gist options
  • Save yovasx2/3a33aa20fa100e79ce7d181b426df26e to your computer and use it in GitHub Desktop.
Save yovasx2/3a33aa20fa100e79ce7d181b426df26e to your computer and use it in GitHub Desktop.
# Train a small fully-connected classifier on Fashion-MNIST and inspect one
# test prediction (originally a sequence of Jupyter notebook cells).
import torch
from torch import nn
from torch import optim
from torchvision import datasets, transforms
import helper

# Define a transform to normalize the data.
# Fashion-MNIST images are single-channel (grayscale), so Normalize gets a
# one-element mean/std. A 3-element tuple is meant for RGB input; recent
# torchvision versions raise a shape-mismatch error when it is applied to a
# 1-channel tensor, and older versions silently used only the first value.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# Download and load the training data
trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
# Fully-connected classifier: 784 input features -> 128 -> 64 -> 10 classes.
# The final LogSoftmax makes the network output log-probabilities, so the
# matching loss is NLLLoss. CrossEntropyLoss would apply log-softmax a second
# time: redundant (log_softmax is idempotent, so the loss value is the same)
# and misleading about what the model actually returns.
model = nn.Sequential(nn.Linear(784, 128),
                      nn.ReLU(),
                      nn.Linear(128, 64),
                      nn.ReLU(),
                      nn.Linear(64, 10),
                      nn.LogSoftmax(dim=1))
# Negative log-likelihood over the log-probabilities produced above.
criterion = nn.NLLLoss()
# Plain SGD over all model parameters with a fixed learning rate.
optimizer = optim.SGD(model.parameters(), lr=0.003)
# Train for a fixed number of epochs; each epoch is one full pass over
# trainloader.
epochs = 50
model.train()  # no dropout/batch-norm in this model, but keep the mode explicit
for e in range(epochs):
    running_loss = 0
    for images, labels in trainloader:
        # Flatten each image into a 784-long vector
        images = images.view(images.shape[0], -1)

        # Training pass: reset gradients, forward, loss, backprop, update
        optimizer.zero_grad()
        log_ps = model(images)
        loss = criterion(log_ps, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean per-batch loss for this epoch. The original used a
    # for/else with no `break`, so its else-branch always ran; a plain
    # statement after the loop is equivalent and clearer.
    print(f"Step {e} => Training loss: {running_loss/len(trainloader)}")
# Use the training
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import helper
import torch.nn.functional as F
# Test out your network!
dataiter = iter(testloader)
images, labels = dataiter.next()
img = images[0]
# Convert 2D image to 1D vector
img = img.resize_(1, 784)
# TODO: Calculate the class probabilities (softmax) for img
with torch.no_grad():
logits = model.forward(img)
# Output of the network are logits, need to take softmax for probabilities
ps = F.softmax(logits, dim=1)
# Plot the image and probabilities
helper.view_classify(img.resize_(1, 28, 28), ps, version='Fashion')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment