Skip to content

Instantly share code, notes, and snippets.

@alexbw
Last active May 4, 2017 15:45
Show Gist options
  • Save alexbw/7b2a0682f65dd1bcb7120ca2d47a2823 to your computer and use it in GitHub Desktop.
Save alexbw/7b2a0682f65dd1bcb7120ca2d47a2823 to your computer and use it in GitHub Desktop.
from pylab import *
from autograd import grad
import autograd.numpy as np
import torch
from torch.autograd import Variable
from memory_profiler import memory_usage
# Problem dimensions shared by both benchmarks.
batch_size = 16
D = 2 ** 10
n = 50  # number of tanh recurrence steps unrolled per forward pass


def _small_gaussian(*shape):
    """Return float32 samples of ``shape`` drawn from N(0, 1), scaled by 0.01."""
    return 0.01 * np.random.randn(*shape).astype('float32')


# Random draws happen in the same order as before: x, W1, b1, Wout, bout, l.
x = _small_gaussian(batch_size, D)
W1 = _small_gaussian(D, D)
b1 = _small_gaussian(D)
Wout = _small_gaussian(D, 1)
bout = _small_gaussian(1)
# Binary float32 targets: each entry is 1.0 when a uniform draw exceeds 0.5.
l = (np.random.rand(batch_size, 1) > 0.5).astype(np.float32)
# Autograd
def autograd_rnn(params, x, label, n):
    """Return the summed sigmoid cross-entropy of an n-step tanh recurrence.

    Args:
        params: tuple ``(W, b, Wout, bout)`` of recurrent weight, recurrent
            bias, readout weight and readout bias.
        x: initial hidden state / input batch.
        label: targets, combined elementwise with the logits.
        n: number of ``h <- tanh(h @ W + b)`` steps.

    Returns:
        Scalar ``-sum(label*logit - (logit + log(1 + exp(-logit))))``.

    NOTE: ``exp(-logit)`` can overflow for large negative logits; the formula
    is kept as-is so the workload matches the PyTorch variant.
    """
    weight, bias, out_weight, out_bias = params
    hidden = x
    for _ in range(n):
        pre_activation = np.dot(hidden, weight) + bias
        hidden = np.tanh(pre_activation)
    logit = np.dot(hidden, out_weight) + out_bias
    softplus_neg = np.log(1 + np.exp(-logit))
    per_element = label * logit - (logit + softplus_neg)
    return -np.sum(per_element)
grad_rnn = grad(autograd_rnn)


def m():
    """Profile target: one gradient evaluation of ``autograd_rnn`` (result discarded)."""
    params = (W1, b1, Wout, bout)
    grad_rnn(params, x, l, n=n)


mem_usage_autograd = np.array(memory_usage(m, interval=0.01))
# Express the trace relative to its first sample so it starts at zero.
mem_usage_autograd = mem_usage_autograd - mem_usage_autograd[0]
# PyTorch
# Wrap the module-level NumPy inputs as Variables; only the model
# parameters require gradients.
tx = Variable(torch.from_numpy(x), requires_grad=False)
tW1, tb1, tWout, tbout = [
    Variable(torch.from_numpy(arr), requires_grad=True)
    for arr in (W1, b1, Wout, bout)
]
tl = Variable(torch.from_numpy(l))
def torch_rnn(x, W, b, Wout, bout, label, n):
    """Forward and backward pass of the n-step tanh RNN benchmark in PyTorch.

    Computes ``h <- tanh(h @ W + b)`` ``n`` times starting from ``x``, then a
    readout ``logit = h @ Wout + bout`` and the summed sigmoid cross-entropy
    ``-sum(label*logit - (logit + log(1 + exp(-logit))))``. Finally it calls
    ``loss.backward()``.

    Args:
        x: 2-D input; its first dimension is the batch size.
        W, b: recurrent weight matrix and bias vector.
        Wout, bout: readout weight matrix and bias vector; ``bout`` needs one
            entry per readout column (``Wout.size(1)``).
        label: targets, presumably 0/1 with the same shape as ``logit``.
        n: number of recurrent steps.

    Returns:
        ``(loss, [W.grad, b.grad, Wout.grad, bout.grad])``. Existing gradients
        are zeroed in place before ``backward()``, so the returned gradients
        belong to this call only instead of accumulating across calls.
    """
    batch = x.size(0)
    # Biases are expanded explicitly to the exact output shapes, which works
    # with and without PyTorch broadcasting. The previous
    # ``bout.expand(batch)`` had shape (batch,); added to the (batch, 1)
    # logits on broadcasting PyTorch it became a (batch, batch) matrix,
    # silently corrupting both the loss and the gradients.
    b_rows = torch.unsqueeze(b, 0).expand(batch, b.size(0))
    h1 = x
    for _ in range(n):
        h1 = torch.tanh(torch.mm(h1, W) + b_rows)
    readout = torch.mm(h1, Wout)
    logit = readout + torch.unsqueeze(bout, 0).expand_as(readout)
    loss = -torch.sum(label * logit - (
        logit + torch.log(1 + torch.exp(-logit))))
    params = (W, b, Wout, bout)
    for p in params:
        # Leaf gradients accumulate across backward() calls; reset them so a
        # repeated call reports only this call's gradients.
        if p.grad is not None:
            p.grad.data.zero_()
    loss.backward()
    return loss, [p.grad for p in params]
def m():
    """Profile target: one forward+backward pass of ``torch_rnn`` (result discarded)."""
    torch_rnn(tx, tW1, tb1, tWout, tbout, tl, n)


mem_usage_torch = np.array(memory_usage(m, interval=0.01))
# Express the trace relative to its first sample so it starts at zero.
mem_usage_torch = mem_usage_torch - mem_usage_torch[0]
# Plot both memory traces against elapsed time.
# sample_interval must equal the ``interval=`` passed to memory_usage() for
# both traces; the x axis assumes one sample per interval. Building the time
# axis from it replaces the old ``xticks(xticks()[0], xticks()[0]/100)`` hack.
# That hack froze tick positions (which can widen the view to out-of-range
# ticks), hard-coded 1/interval = 100, and left labels stale on zoom/pan.
sample_interval = 0.01
clf()
plot(np.arange(len(mem_usage_autograd)) * sample_interval,
     mem_usage_autograd, label="Autograd")
plot(np.arange(len(mem_usage_torch)) * sample_interval,
     mem_usage_torch, label="PyTorch")
ylabel("Memory Usage(MB)")
xlabel("Time (sec)")
legend()
show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment