@riveSunder
Created September 3, 2020 20:12
Benchmarks comparing Autograd, JAX, and PyTorch for matmuls and convolutions, with emphasis on RL agent mock rollouts.
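
Each "matmuls" line below reports the wall-clock time for 1000 matrix multiplies of an (n, dim_x) input against a (dim_x, 1) weight, measured once with Autograd's NumPy wrapper, once with a jitted JAX function (jit compilation time included), and once with un-jitted JAX; the "env ... completed" lines report the time for 1000 environment steps with each framework's forward pass inside the loop. A minimal sketch of the matmul comparison, condensed from the full script included after the output (warm-up and timer granularity are not controlled for here):

    import time
    from autograd import numpy as anp
    import numpy.random as npr
    from jax import numpy as jnp, jit

    x, w = npr.randn(512, 1024), npr.randn(1024, 1)
    jax_matmul = jit(jnp.matmul)

    t0 = time.time()
    for _ in range(1000):
        _ = anp.matmul(x, w)
    t1 = time.time()
    for _ in range(1000):
        _ = jax_matmul(x, w)
    t2 = time.time()
    print("ag: {:.2e} s, JAX: {:.2e} s".format(t1 - t0, t2 - t1))
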
matmuls ag: 1.96e-03 s, JAX: 7.68e-01, JAX no jit: 1.07e-01. n = 1, dim_x = 64, steps = 1000
matmuls ag: 1.98e-03 s, JAX: 2.89e-01, JAX no jit: 1.09e-01. n = 1, dim_x = 256, steps = 1000
matmuls ag: 2.16e-03 s, JAX: 2.96e-01, JAX no jit: 1.13e-01. n = 1, dim_x = 1024, steps = 1000
matmuls ag: 2.71e-03 s, JAX: 2.97e-01, JAX no jit: 1.14e-01. n = 1, dim_x = 4096, steps = 1000
matmuls ag: 3.44e-03 s, JAX: 3.51e-01, JAX no jit: 1.14e-01. n = 1, dim_x = 8192, steps = 1000
matmuls ag: 2.23e-03 s, JAX: 2.43e-01, JAX no jit: 1.07e-01. n = 32, dim_x = 64, steps = 1000
matmuls ag: 2.94e-03 s, JAX: 1.83e-01, JAX no jit: 1.09e-01. n = 32, dim_x = 256, steps = 1000
matmuls ag: 7.45e-03 s, JAX: 2.41e-01, JAX no jit: 1.31e-01. n = 32, dim_x = 1024, steps = 1000
matmuls ag: 1.27e-02 s, JAX: 2.42e-01, JAX no jit: 1.11e-01. n = 32, dim_x = 4096, steps = 1000
matmuls ag: 2.21e-02 s, JAX: 2.46e-01, JAX no jit: 1.45e-01. n = 32, dim_x = 8192, steps = 1000
matmuls ag: 1.95e-02 s, JAX: 2.38e-01, JAX no jit: 1.30e-01. n = 512, dim_x = 64, steps = 1000
matmuls ag: 2.65e-02 s, JAX: 2.47e-01, JAX no jit: 1.11e-01. n = 512, dim_x = 256, steps = 1000
matmuls ag: 3.80e-02 s, JAX: 1.96e-01, JAX no jit: 1.11e-01. n = 512, dim_x = 1024, steps = 1000
matmuls ag: 9.57e-02 s, JAX: 1.95e-01, JAX no jit: 1.27e-01. n = 512, dim_x = 4096, steps = 1000
matmuls ag: 1.69e-01 s, JAX: 2.34e-01, JAX no jit: 1.39e-01. n = 512, dim_x = 8192, steps = 1000
matmuls ag: 3.39e-02 s, JAX: 1.97e-01, JAX no jit: 1.23e-01. n = 4096, dim_x = 64, steps = 1000
matmuls ag: 7.33e-02 s, JAX: 2.20e-01, JAX no jit: 1.22e-01. n = 4096, dim_x = 256, steps = 1000
matmuls ag: 1.78e-01 s, JAX: 2.83e-01, JAX no jit: 1.43e-01. n = 4096, dim_x = 1024, steps = 1000
matmuls ag: 6.86e-01 s, JAX: 5.85e-01, JAX no jit: 4.29e-01. n = 4096, dim_x = 4096, steps = 1000
matmuls ag: 5.35e+00 s, JAX: 9.36e-01, JAX no jit: 8.28e-01. n = 4096, dim_x = 8192, steps = 1000
matmuls ag: 5.09e-02 s, JAX: 1.91e-01, JAX no jit: 1.11e-01. n = 8192, dim_x = 64, steps = 1000
matmuls ag: 1.06e-01 s, JAX: 1.93e-01, JAX no jit: 1.30e-01. n = 8192, dim_x = 256, steps = 1000
matmuls ag: 2.91e-01 s, JAX: 3.46e-01, JAX no jit: 2.40e-01. n = 8192, dim_x = 1024, steps = 1000
matmuls ag: 5.29e+00 s, JAX: 9.32e-01, JAX no jit: 8.28e-01. n = 8192, dim_x = 4096, steps = 1000
matmuls ag: 1.07e+01 s, JAX: 1.78e+00, JAX no jit: 1.63e+00. n = 8192, dim_x = 8192, steps = 1000
pybullet build time: Jul 8 2020 18:23:32
/home/rlygr/.venv/jax/lib/python3.6/site-packages/gym/logger.py:30: UserWarning: WARN: Box bound precision lowered by casting to float32
warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))
env InvertedPendulumBulletEnv-v0 completed 1000 steps in 1.612466812133789 s by <class '__main__.JAXDummy'>
env InvertedPendulumBulletEnv-v0 completed 1000 steps in 0.10097122192382812 s by <class '__main__.TorchDummy'>
env InvertedPendulumBulletEnv-v0 completed 1000 steps in 0.1425793170928955 s by <class '__main__.TorchNNDummy'>
env InvertedPendulumBulletEnv-v0 completed 1000 steps in 0.06732368469238281 s by <class '__main__.AutogradDummy'>
env InvertedPendulumBulletEnv-v0 completed 1000 steps in 0.05839872360229492 s by <class '__main__.NPDummy'>
env InvertedDoublePendulumBulletEnv-v0 completed 1000 steps in 1.2599716186523438 s by <class '__main__.JAXDummy'>
env InvertedDoublePendulumBulletEnv-v0 completed 1000 steps in 0.11095952987670898 s by <class '__main__.TorchDummy'>
env InvertedDoublePendulumBulletEnv-v0 completed 1000 steps in 0.15032553672790527 s by <class '__main__.TorchNNDummy'>
env InvertedDoublePendulumBulletEnv-v0 completed 1000 steps in 0.07796144485473633 s by <class '__main__.AutogradDummy'>
env InvertedDoublePendulumBulletEnv-v0 completed 1000 steps in 0.06907510757446289 s by <class '__main__.NPDummy'>
env AntBulletEnv-v0 completed 1000 steps in 7.337543725967407 s by <class '__main__.JAXDummy'>
env AntBulletEnv-v0 completed 1000 steps in 0.6920044422149658 s by <class '__main__.TorchDummy'>
env AntBulletEnv-v0 completed 1000 steps in 0.7649385929107666 s by <class '__main__.TorchNNDummy'>
env AntBulletEnv-v0 completed 1000 steps in 0.6465122699737549 s by <class '__main__.AutogradDummy'>
env AntBulletEnv-v0 completed 1000 steps in 0.6358811855316162 s by <class '__main__.NPDummy'>
env HumanoidBulletEnv-v0 completed 1000 steps in 14.182511568069458 s by <class '__main__.JAXDummy'>
env HumanoidBulletEnv-v0 completed 1000 steps in 1.0015196800231934 s by <class '__main__.TorchDummy'>
env HumanoidBulletEnv-v0 completed 1000 steps in 1.1058030128479004 s by <class '__main__.TorchNNDummy'>
env HumanoidBulletEnv-v0 completed 1000 steps in 0.9539287090301514 s by <class '__main__.AutogradDummy'>
env HumanoidBulletEnv-v0 completed 1000 steps in 0.9695944786071777 s by <class '__main__.NPDummy'>
env procgen:procgen-coinrun-v0 completed 1000 steps in 6.126019239425659 s by <class '__main__.TorchConvDummy'>
env procgen:procgen-coinrun-v0 completed 1000 steps in 234.86034059524536 s by <class '__main__.AutogradConvDummy'>
env procgen:procgen-coinrun-v0 completed 1000 steps in 5.710960626602173 s by <class '__main__.JAXConvDummy'>
from abc import ABC, abstractmethod
from collections import OrderedDict
from functools import reduce
import time
from autograd import numpy as anp
import autograd.scipy.signal
convolve = autograd.scipy.signal.convolve
import jax.lax
import jax.scipy.signal
#j_convolve = jax.scipy.signal.convolve
j_convolve = jax.lax.conv
from jax import numpy as jnp
from jax import jit
import jax.random as jpr
import numpy.random as npr
import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
import pybullet_envs
import matplotlib.pyplot as plt
def sigmoid(x):
    return 1 / (1 + anp.exp(-x))

def relu(x):
    return anp.maximum(0, x)

class NPDummy():
    def __init__(self, input_dim, output_dim, hid_dim=[32,32]):
        self.dim_x = input_dim
        self.dim_y = output_dim
        self.dim_h = hid_dim
        self.act = [relu] * len(self.dim_h)
        self.act.append(jnp.tanh)
        self.init_parameters()

    def init_parameters(self):
        self.layers = []
        self.layers.append(1.e-1 * npr.randn(self.dim_x, self.dim_h[0]))
        for ii in range(1, len(self.dim_h)-1):
            self.layers.append(1.e-1 * npr.randn(self.dim_h[ii-1], self.dim_h[ii]))
        self.layers.append(1.e-1 * npr.randn(self.dim_h[-1], self.dim_y))

    def forward(self, x):
        return npr.randn(self.dim_y,)

class JAXDummy(NPDummy):
    def __init__(self, input_dim, output_dim, hid_dim=[32,32]):
        super(JAXDummy, self).__init__(input_dim, output_dim, hid_dim)
        self.act = [jnp.tanh] * len(self.dim_h)
        self.act.append(jnp.tanh)

    #@jit
    def forward(self, x):
        for jj, layer in enumerate(self.layers):
            x = self.act[jj](jnp.matmul(x, layer))
        return x
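
# The "#@jit" above is left commented out, presumably because jit-compiling a
# bound method passes `self` as a traced argument, and plain Python objects are
# not valid JAX types. A common workaround, shown here only as a sketch (it is
# not used by the benchmark below), is to jit a pure function of the parameters:
@jit
def jax_forward_jitted(layers, x):
    for layer in layers:
        x = jnp.tanh(jnp.matmul(x, layer))
    return x
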
class AutogradDummy(NPDummy):
    def __init__(self, input_dim, output_dim, hid_dim=[32,32]):
        super(AutogradDummy, self).__init__(input_dim, output_dim, hid_dim)
        self.act = [anp.tanh] * len(self.dim_h)
        self.act.append(anp.tanh)

    def forward(self, x):
        for jj, layer in enumerate(self.layers):
            x = self.act[jj](anp.matmul(x, layer))
        return x

class TorchDummy(NPDummy):
    def __init__(self, input_dim, output_dim, hid_dim=[32,32]):
        super(TorchDummy, self).__init__(input_dim, output_dim, hid_dim)
        self.act = [torch.tanh] * len(self.dim_h)
        self.act.append(torch.tanh)

    def init_parameters(self):
        self.layers = []
        self.layers.append(1.e-1 * torch.randn(self.dim_x, self.dim_h[0]))
        for ii in range(1, len(self.dim_h)-1):
            self.layers.append(1.e-1 * torch.randn(self.dim_h[ii-1], self.dim_h[ii]))
        self.layers.append(1.e-1 * torch.randn(self.dim_h[-1], self.dim_y))

    def forward(self, x):
        x = torch.Tensor(x)
        for jj, layer in enumerate(self.layers):
            x = self.act[jj](torch.matmul(x, layer))
        return x.detach().numpy()

class TorchNNDummy(NPDummy):
    def __init__(self, input_dim, output_dim, hid_dim=[32,32]):
        super(TorchNNDummy, self).__init__(input_dim, output_dim, hid_dim)
        self.act = [torch.tanh] * len(self.dim_h)
        self.act.append(torch.tanh)

    def init_parameters(self):
        self.layers = nn.Sequential(nn.Linear(self.dim_x, self.dim_h[0], bias=False),
                nn.Tanh())
        for ii in range(1, len(self.dim_h)-1):
            self.layers.add_module("layer{}".format(ii),
                    nn.Linear(self.dim_h[ii-1], self.dim_h[ii], bias=False))
            self.layers.add_module("act{}".format(ii),
                    nn.Tanh())
        self.layers.add_module("endlayer",
                nn.Linear(self.dim_h[-1], self.dim_y, bias=False))
        self.layers.add_module("endact",
                nn.Tanh())

    def forward(self, x):
        x = torch.Tensor(x)
        return self.layers(x).detach().numpy()

class TorchConvDummy():
    def __init__(self, input_dim, output_dim, hid_dim=[32,32,32,32]):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dim_h = hid_dim
        self.init_parameters()

    def init_parameters(self):
        self.layers = nn.Sequential(nn.Conv2d(3, self.dim_h[0], 3, stride=2,
                padding=1, bias=False), nn.Tanh())
        for ii in range(1, len(self.dim_h)):
            self.layers.add_module("layer{}".format(ii),
                    nn.Conv2d(self.dim_h[ii-1], self.dim_h[ii], 3, stride=2,
                    padding=1, bias=False))
            self.layers.add_module("act{}".format(ii),
                    nn.Tanh())
        self.layers.add_module("flattener", nn.Flatten())
        # with the default hid_dim, 64x64 procgen frames shrink to 4x4 after
        # four stride-2 convolutions, so the flattened features are 4*4*32 = 512 wide
        self.layers.add_module("endlayer",
                nn.Linear(512, self.output_dim, bias=False))
        self.layers.add_module("endact",
                nn.Softmax(dim=-1))

    def forward(self, x):
        x = torch.Tensor(x).permute(2, 0, 1).unsqueeze(0)
        return self.layers(x)

    def get_action(self, x):
        x = self.forward(x)
        action = torch.argmax(x)
        return action.detach().numpy()

class AutogradConvDummy():
    def __init__(self, input_dim, output_dim, hid_dim=[32,32,32,32]):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dim_h = hid_dim
        self.init_parameters()

    def init_parameters(self):
        self.kernels = []
        self.kernels.append(1e-1 * npr.randn(3, self.dim_h[0], 3, 3))
        for ii in range(1, len(self.dim_h)):
            self.kernels.append(1e-1 * npr.randn(
                    self.dim_h[ii-1], self.dim_h[ii], 3, 3))
        # four 3x3 "valid" convolutions take 64x64 frames down to 56x56,
        # so the flattened activation is 32 * 56 * 56 = 100352 wide
        self.dense = 1e-1 * npr.randn(100352, 15)

    def softmax(self, x):
        x = x - anp.max(x)
        return anp.exp(x) / anp.sum(anp.exp(x), axis=-1)

    def forward(self, x):
        x = x.transpose(2, 0, 1)[anp.newaxis, :, :, :]
        for jj, kernel in enumerate(self.kernels):
            x = anp.tanh(convolve(x, kernel, axes=([2,3], [2,3]),
                    dot_axes=([1], [0]), mode="valid"))
        x = x.ravel()
        x = self.softmax(anp.matmul(x, self.dense))
        return x

    def get_action(self, x):
        x = self.forward(x)
        action = anp.argmax(x, axis=-1)
        return action

class JAXConvDummy():
    def __init__(self, input_dim, output_dim, hid_dim=[32,32,32,32]):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dim_h = hid_dim
        self.init_parameters()

    def init_parameters(self):
        self.kernels = []
        self.kernels.append(1e-1 * npr.randn(self.dim_h[0], 3, 3, 3))
        for ii in range(1, len(self.dim_h)):
            self.kernels.append(1e-1 * npr.randn(
                    self.dim_h[ii-1], self.dim_h[ii], 3, 3))
        self.dense = 1e-1 * npr.randn(100352, 15)

    def softmax(self, x):
        x = x - jnp.max(x)
        return jnp.exp(x) / jnp.sum(jnp.exp(x), axis=-1)

    def forward(self, x):
        x = jnp.array(x.transpose(2, 0, 1)[jnp.newaxis, :, :, :], dtype=jnp.float32)
        for jj, kernel in enumerate(self.kernels):
            x = jnp.tanh(j_convolve(x, kernel, window_strides=(1,1), padding="VALID"))
        x = x.ravel()
        x = self.softmax(jnp.matmul(x, self.dense))
        return x

    def get_action(self, x):
        x = self.forward(x)
        action = jnp.argmax(x, axis=-1)
        return action
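
# Shape conventions assumed by JAXConvDummy above: jax.lax.conv (aliased to
# j_convolve) expects NCHW inputs and OIHW kernels by default, so a 64x64 RGB
# frame is reshaped to (1, 3, 64, 64) and the first kernel is (dim_h[0], 3, 3, 3).
# Illustrative shape check (a sketch, not executed by the benchmark):
#   j_convolve(jnp.zeros((1, 3, 64, 64)), jnp.zeros((32, 3, 3, 3)),
#              window_strides=(1, 1), padding="VALID").shape  # -> (1, 32, 62, 62)
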
def jax_matmuls(dim_n=256, dim_x=64, num_times=1000, use_jit=True):
    prng_key = jpr.PRNGKey(1)
    x = jpr.normal(prng_key, (dim_n, dim_x))
    w = jpr.normal(prng_key, (dim_x, 1))
    if use_jit:
        for ii in range(num_times):
            _ = jax_jit_matmul(x, w)
    else:
        for ii in range(num_times):
            _ = jax_nojit_matmul(x, w)

@jit
def jax_jit_matmul(x, w):
    return jnp.matmul(x, w)

def jax_nojit_matmul(x, w):
    return jnp.matmul(x, w)
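
# Caveat: JAX dispatches work asynchronously, so the loops in jax_matmuls can
# return before the device has finished computing. A synchronized variant
# (a sketch only; the timings reported above were produced without it) would
# block on the final result before the caller stops its clock:
def jax_matmuls_sync(dim_n=256, dim_x=64, num_times=1000):
    prng_key = jpr.PRNGKey(1)
    x = jpr.normal(prng_key, (dim_n, dim_x))
    w = jpr.normal(prng_key, (dim_x, 1))
    out = None
    for ii in range(num_times):
        out = jax_jit_matmul(x, w)
    out.block_until_ready()
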
def ag_matmuls(dim_n=256, dim_x=64, num_times=1000):
    x = npr.randn(dim_n, dim_x)
    w = npr.randn(dim_x, 1)
    for ii in range(num_times):
        _ = anp.matmul(x, w)

if __name__ == "__main__":
    max_steps = 1000

    if(1): # skip this section if you only want to benchmark convolutions
        for dim_n in [1, 32, 512, 4096, 8192]:
            for dim_x in [64, 256, 1024, 4096, 8192]:
                t00 = time.time()
                jax_matmuls(dim_n=dim_n, dim_x=dim_x, num_times=max_steps)
                t11 = time.time()
                ag_matmuls(dim_n=dim_n, dim_x=dim_x, num_times=max_steps)
                t22 = time.time()
                jax_matmuls(dim_n=dim_n, dim_x=dim_x, num_times=max_steps, use_jit=False)
                t33 = time.time()
                print("matmuls ag: {:.2e} s, JAX: {:.2e}, JAX no jit: {:.2e}. n = {}, dim_x = {}, steps = {}"\
                        .format(t22-t11, t11-t00, t33-t22, dim_n, dim_x, max_steps))

        for env_name in ["InvertedPendulumBulletEnv-v0",
                "InvertedDoublePendulumBulletEnv-v0",
                "AntBulletEnv-v0",
                "HumanoidBulletEnv-v0"]:
            env = gym.make(env_name)
            input_dim = env.observation_space.sample().shape[0]
            output_dim = env.action_space.sample().shape[0]

            for agent_fn in [JAXDummy, TorchDummy, TorchNNDummy, AutogradDummy, NPDummy]:
                steps = 0
                model = agent_fn(input_dim=input_dim, output_dim=output_dim, hid_dim=[32,32])
                t0 = time.time()
                done = True
                while steps < max_steps:
                    if done:
                        obs = env.reset()
                        done = False
                    obs, reward, done, info = env.step(model.forward(obs))
                    steps += 1
                t1 = time.time()
                print("env {} completed {} steps in {} s by ".format(env_name, steps, t1-t0), agent_fn)
            env.close()

    for env_name in ["procgen:procgen-coinrun-v0"]:
        env = gym.make(env_name)
        input_dim = env.observation_space.sample().shape[0]
        output_dim = 15

        for agent_fn in [TorchConvDummy, AutogradConvDummy, JAXConvDummy]:
            steps = 0
            model = agent_fn(input_dim=input_dim, output_dim=output_dim)
            t0 = time.time()
            done = True
            while steps < max_steps:
                if done:
                    obs = env.reset()
                    done = False
                obs, reward, done, info = env.step(model.get_action(obs))
                steps += 1
            t1 = time.time()
            print("env {} completed {} steps in {} s by ".format(env_name, steps, t1-t0), agent_fn)
        env.close()
absl-py==0.9.0
atari-py==0.2.6
autograd==1.3
cffi==1.14.0
cloudpickle==1.3.0
cycler==0.10.0
filelock==3.0.12
future==0.18.2
glcontext==2.2.0
glfw==1.12.0
gym==0.17.2
gym3==0.3.3
imageio==2.9.0
imageio-ffmpeg==0.3.0
jax==0.1.73
jaxlib @ https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.51-cp36-none-manylinux2010_x86_64.whl
kiwisolver==1.2.0
matplotlib==3.3.0
moderngl==5.6.1
numpy==1.19.1
opencv-python==4.3.0.36
opt-einsum==3.3.0
Pillow==7.2.0
pkg-resources==0.0.0
procgen==0.10.4
pybullet==2.8.4
pycparser==2.20
pyglet==1.5.0
pyparsing==2.4.7
python-dateutil==2.8.1
scipy==1.5.1
six==1.15.0
torch==1.5.1