Skip to content

Instantly share code, notes, and snippets.

@BrambleXu
Created January 18, 2020 05:01
Show Gist options
Save BrambleXu/b901858e5ede6ef8ccdf468b69fe6e85 to your computer and use it in GitHub Desktop.
"""
pip install torch==1.4.0 torchvision==0.5.0 tensorboard==2.1.0
command:
python tensorboard_epoch_demo.py
tensorboard --logdir=runs
"""
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
torch.set_printoptions(linewidth=120) # Display options for output
torch.set_grad_enabled(True) # Already on by default
from torch.utils.tensorboard import SummaryWriter # new
print(torch.__version__)
print(torchvision.__version__)
def get_num_correct(preds, labels):
    """Count the samples whose highest-scoring class matches the label.

    Args:
        preds: score tensor; the predicted class is the argmax along dim 1
            (presumably shaped (batch, num_classes) — confirm with callers).
        labels: tensor of target class indices, compared element-wise with
            the predicted classes.

    Returns:
        The number of matches as a Python int.
    """
    predicted_classes = preds.argmax(dim=1)
    matches = predicted_classes == labels
    return matches.sum().item()
class Network(nn.Module):
    """Small CNN classifier: two convolution stages followed by three linear layers.

    Each convolution stage is a 5x5 convolution -> ReLU -> 2x2 max-pool with
    stride 2. ``fc1`` expects 12 * 4 * 4 = 192 flattened features, which is
    what a single-channel 28x28 input yields after both stages
    (28 -> 24 -> 12 -> 8 -> 4). ``forward`` returns the raw output of the
    last linear layer (10 values per sample); no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        # Attribute assignment order fixes the order of network.parameters().
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 12, kernel_size=5)
        self.fc1 = nn.Linear(12 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 60)
        self.out = nn.Linear(60, 10)

    def forward(self, t):
        """Run a batch of single-channel images through the network.

        Returns the 10 unnormalized class scores for each sample.
        """
        # Feature extractor: conv -> ReLU -> max-pool, applied twice.
        for conv in (self.conv1, self.conv2):
            t = F.max_pool2d(F.relu(conv(t)), kernel_size=2, stride=2)
        # Keep the batch dimension, collapse everything else into features.
        t = t.flatten(start_dim=1)
        # Hidden fully connected layers with ReLU.
        for fc in (self.fc1, self.fc2):
            t = F.relu(fc(t))
        # Output layer: no activation.
        return self.out(t)
# Training hyperparameters (previously hard-coded inline).
NUM_EPOCHS = 10
BATCH_SIZE = 100
LEARNING_RATE = 0.01

train_set = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

network = Network()
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
optimizer = optim.Adam(network.parameters(), lr=LEARNING_RATE)

# One sample batch, used only for the image grid and for tracing the model graph.
images, labels = next(iter(train_loader))
grid = torchvision.utils.make_grid(images)

# Tensorboard lines
tb = SummaryWriter()
try:
    tb.add_image('images', grid)
    tb.add_graph(network, images)

    for epoch in range(NUM_EPOCHS):
        # Note: total_loss is the SUM of per-batch mean losses
        # (F.cross_entropy uses reduction='mean' by default), so its scale
        # depends on the number of batches per epoch.
        total_loss = 0
        total_correct = 0
        for images, labels in train_loader:  # Get Batch
            preds = network(images)  # Pass Batch
            loss = F.cross_entropy(preds, labels)  # Calculating Loss

            optimizer.zero_grad()
            loss.backward()  # Calculating Gradients
            optimizer.step()  # Update Weights

            total_loss += loss.item()
            total_correct += get_num_correct(preds, labels)

        # Per-epoch TensorBoard logging. Accuracy is measured on the training
        # forward passes of this epoch, while the weights were still changing.
        tb.add_scalar('Loss', total_loss, epoch)
        tb.add_scalar('Number Correct', total_correct, epoch)
        tb.add_scalar('Accuracy', total_correct / len(train_set), epoch)

        tb.add_histogram('conv1.bias', network.conv1.bias, epoch)
        tb.add_histogram('conv1.weight', network.conv1.weight, epoch)
        tb.add_histogram('conv1.weight.grad', network.conv1.weight.grad, epoch)

        print('epoch', epoch, 'total_correct:', total_correct, 'loss:', total_loss)
finally:
    # Always flush and close the event file, even if training raises or is
    # interrupted, so the already-logged epochs are not lost.
    tb.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment