Created September 3, 2020
Benchmarks comparing Autograd, JAX, and PyTorch for matmuls and convolutions, with emphasis on RL agent mock rollouts.
matmuls ag: 1.96e-03 s, JAX: 7.68e-01 s, JAX no jit: 1.07e-01 s. n = 1, dim_x = 64, steps = 1000
matmuls ag: 1.98e-03 s, JAX: 2.89e-01 s, JAX no jit: 1.09e-01 s. n = 1, dim_x = 256, steps = 1000
matmuls ag: 2.16e-03 s, JAX: 2.96e-01 s, JAX no jit: 1.13e-01 s. n = 1, dim_x = 1024, steps = 1000
matmuls ag: 2.71e-03 s, JAX: 2.97e-01 s, JAX no jit: 1.14e-01 s. n = 1, dim_x = 4096, steps = 1000
matmuls ag: 3.44e-03 s, JAX: 3.51e-01 s, JAX no jit: 1.14e-01 s. n = 1, dim_x = 8192, steps = 1000
matmuls ag: 2.23e-03 s, JAX: 2.43e-01 s, JAX no jit: 1.07e-01 s. n = 32, dim_x = 64, steps = 1000
matmuls ag: 2.94e-03 s, JAX: 1.83e-01 s, JAX no jit: 1.09e-01 s. n = 32, dim_x = 256, steps = 1000
matmuls ag: 7.45e-03 s, JAX: 2.41e-01 s, JAX no jit: 1.31e-01 s. n = 32, dim_x = 1024, steps = 1000
matmuls ag: 1.27e-02 s, JAX: 2.42e-01 s, JAX no jit: 1.11e-01 s. n = 32, dim_x = 4096, steps = 1000
matmuls ag: 2.21e-02 s, JAX: 2.46e-01 s, JAX no jit: 1.45e-01 s. n = 32, dim_x = 8192, steps = 1000
matmuls ag: 1.95e-02 s, JAX: 2.38e-01 s, JAX no jit: 1.30e-01 s. n = 512, dim_x = 64, steps = 1000
matmuls ag: 2.65e-02 s, JAX: 2.47e-01 s, JAX no jit: 1.11e-01 s. n = 512, dim_x = 256, steps = 1000
matmuls ag: 3.80e-02 s, JAX: 1.96e-01 s, JAX no jit: 1.11e-01 s. n = 512, dim_x = 1024, steps = 1000
matmuls ag: 9.57e-02 s, JAX: 1.95e-01 s, JAX no jit: 1.27e-01 s. n = 512, dim_x = 4096, steps = 1000
matmuls ag: 1.69e-01 s, JAX: 2.34e-01 s, JAX no jit: 1.39e-01 s. n = 512, dim_x = 8192, steps = 1000
matmuls ag: 3.39e-02 s, JAX: 1.97e-01 s, JAX no jit: 1.23e-01 s. n = 4096, dim_x = 64, steps = 1000
matmuls ag: 7.33e-02 s, JAX: 2.20e-01 s, JAX no jit: 1.22e-01 s. n = 4096, dim_x = 256, steps = 1000
matmuls ag: 1.78e-01 s, JAX: 2.83e-01 s, JAX no jit: 1.43e-01 s. n = 4096, dim_x = 1024, steps = 1000
matmuls ag: 6.86e-01 s, JAX: 5.85e-01 s, JAX no jit: 4.29e-01 s. n = 4096, dim_x = 4096, steps = 1000
matmuls ag: 5.35e+00 s, JAX: 9.36e-01 s, JAX no jit: 8.28e-01 s. n = 4096, dim_x = 8192, steps = 1000
matmuls ag: 5.09e-02 s, JAX: 1.91e-01 s, JAX no jit: 1.11e-01 s. n = 8192, dim_x = 64, steps = 1000
matmuls ag: 1.06e-01 s, JAX: 1.93e-01 s, JAX no jit: 1.30e-01 s. n = 8192, dim_x = 256, steps = 1000
matmuls ag: 2.91e-01 s, JAX: 3.46e-01 s, JAX no jit: 2.40e-01 s. n = 8192, dim_x = 1024, steps = 1000
matmuls ag: 5.29e+00 s, JAX: 9.32e-01 s, JAX no jit: 8.28e-01 s. n = 8192, dim_x = 4096, steps = 1000
matmuls ag: 1.07e+01 s, JAX: 1.78e+00 s, JAX no jit: 1.63e+00 s. n = 8192, dim_x = 8192, steps = 1000
pybullet build time: Jul 8 2020 18:23:32
/home/rlygr/.venv/jax/lib/python3.6/site-packages/gym/logger.py:30: UserWarning: WARN: Box bound precision lowered by casting to float32
  warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))
env InvertedPendulumBulletEnv-v0 completed 1000 steps in 1.612466812133789 s by <class '__main__.JAXDummy'>
env InvertedPendulumBulletEnv-v0 completed 1000 steps in 0.10097122192382812 s by <class '__main__.TorchDummy'>
env InvertedPendulumBulletEnv-v0 completed 1000 steps in 0.1425793170928955 s by <class '__main__.TorchNNDummy'>
env InvertedPendulumBulletEnv-v0 completed 1000 steps in 0.06732368469238281 s by <class '__main__.AutogradDummy'>
env InvertedPendulumBulletEnv-v0 completed 1000 steps in 0.05839872360229492 s by <class '__main__.NPDummy'>
env InvertedDoublePendulumBulletEnv-v0 completed 1000 steps in 1.2599716186523438 s by <class '__main__.JAXDummy'>
env InvertedDoublePendulumBulletEnv-v0 completed 1000 steps in 0.11095952987670898 s by <class '__main__.TorchDummy'>
env InvertedDoublePendulumBulletEnv-v0 completed 1000 steps in 0.15032553672790527 s by <class '__main__.TorchNNDummy'>
env InvertedDoublePendulumBulletEnv-v0 completed 1000 steps in 0.07796144485473633 s by <class '__main__.AutogradDummy'>
env InvertedDoublePendulumBulletEnv-v0 completed 1000 steps in 0.06907510757446289 s by <class '__main__.NPDummy'>
env AntBulletEnv-v0 completed 1000 steps in 7.337543725967407 s by <class '__main__.JAXDummy'>
env AntBulletEnv-v0 completed 1000 steps in 0.6920044422149658 s by <class '__main__.TorchDummy'>
env AntBulletEnv-v0 completed 1000 steps in 0.7649385929107666 s by <class '__main__.TorchNNDummy'>
env AntBulletEnv-v0 completed 1000 steps in 0.6465122699737549 s by <class '__main__.AutogradDummy'>
env AntBulletEnv-v0 completed 1000 steps in 0.6358811855316162 s by <class '__main__.NPDummy'>
env HumanoidBulletEnv-v0 completed 1000 steps in 14.182511568069458 s by <class '__main__.JAXDummy'>
env HumanoidBulletEnv-v0 completed 1000 steps in 1.0015196800231934 s by <class '__main__.TorchDummy'>
env HumanoidBulletEnv-v0 completed 1000 steps in 1.1058030128479004 s by <class '__main__.TorchNNDummy'>
env HumanoidBulletEnv-v0 completed 1000 steps in 0.9539287090301514 s by <class '__main__.AutogradDummy'>
env HumanoidBulletEnv-v0 completed 1000 steps in 0.9695944786071777 s by <class '__main__.NPDummy'>
env procgen:procgen-coinrun-v0 completed 1000 steps in 6.126019239425659 s by <class '__main__.TorchConvDummy'>
env procgen:procgen-coinrun-v0 completed 1000 steps in 234.86034059524536 s by <class '__main__.AutogradConvDummy'>
env procgen:procgen-coinrun-v0 completed 1000 steps in 5.710960626602173 s by <class '__main__.JAXConvDummy'>
from abc import ABC, abstractmethod
from collections import OrderedDict
from functools import reduce
import time

from autograd import numpy as anp
import autograd.scipy.signal
convolve = autograd.scipy.signal.convolve

import jax.lax
import jax.scipy.signal
#j_convolve = jax.scipy.signal.convolve
j_convolve = jax.lax.conv
from jax import numpy as jnp
from jax import jit
import jax.random as jpr

import numpy.random as npr

import torch
import torch.nn as nn
import torch.nn.functional as F

import gym
import pybullet_envs

import matplotlib.pyplot as plt
def sigmoid(x):
    return 1 / (1 + anp.exp(-x))

def relu(x):
    return anp.maximum(0, x)
class NPDummy():
    def __init__(self, input_dim, output_dim, hid_dim=[32,32]):
        self.dim_x = input_dim
        self.dim_y = output_dim
        self.dim_h = hid_dim
        self.act = [relu] * len(self.dim_h)
        self.act.append(anp.tanh)
        self.init_parameters()

    def init_parameters(self):
        # one weight matrix per hidden layer plus the output layer,
        # so the layer count matches the activation list built above
        self.layers = []
        self.layers.append(1.e-1 * npr.randn(self.dim_x, self.dim_h[0]))
        for ii in range(1, len(self.dim_h)):
            self.layers.append(1.e-1 * npr.randn(self.dim_h[ii-1], self.dim_h[ii]))
        self.layers.append(1.e-1 * npr.randn(self.dim_h[-1], self.dim_y))

    def forward(self, x):
        # baseline: ignores the initialized weights and returns a random
        # action, so this effectively times environment overhead alone
        return npr.randn(self.dim_y,)
class JAXDummy(NPDummy):
    def __init__(self, input_dim, output_dim, hid_dim=[32,32]):
        super(JAXDummy, self).__init__(input_dim, output_dim, hid_dim)
        self.act = [jnp.tanh] * len(self.dim_h)
        self.act.append(jnp.tanh)

    #@jit -- can't be applied to a bound method directly: jit traces every
    # argument, and `self` is not a valid JAX type
    def forward(self, x):
        for jj, layer in enumerate(self.layers):
            x = self.act[jj](jnp.matmul(x, layer))
        return x
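# A minimal sketch (not part of the original gist) of how this forward pass
# could actually be jit-compiled: jax.jit needs a pure function of JAX-compatible
# arguments, so the weight list is passed in explicitly rather than read from
# self. Usage would look like: action = jax_forward_jitted(model.layers, obs)
@jit
def jax_forward_jitted(layers, x):
    for layer in layers:
        x = jnp.tanh(jnp.matmul(x, layer))
    return x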
class AutogradDummy(NPDummy):
    def __init__(self, input_dim, output_dim, hid_dim=[32,32]):
        super(AutogradDummy, self).__init__(input_dim, output_dim, hid_dim)
        self.act = [anp.tanh] * len(self.dim_h)
        self.act.append(anp.tanh)

    def forward(self, x):
        for jj, layer in enumerate(self.layers):
            x = self.act[jj](anp.matmul(x, layer))
        return x
class TorchDummy(NPDummy):
    def __init__(self, input_dim, output_dim, hid_dim=[32,32]):
        super(TorchDummy, self).__init__(input_dim, output_dim, hid_dim)
        self.act = [torch.tanh] * len(self.dim_h)
        self.act.append(torch.tanh)

    def init_parameters(self):
        self.layers = []
        self.layers.append(1.e-1 * torch.randn(self.dim_x, self.dim_h[0]))
        for ii in range(1, len(self.dim_h)):
            self.layers.append(1.e-1 * torch.randn(self.dim_h[ii-1], self.dim_h[ii]))
        self.layers.append(1.e-1 * torch.randn(self.dim_h[-1], self.dim_y))

    def forward(self, x):
        x = torch.Tensor(x)
        for jj, layer in enumerate(self.layers):
            x = self.act[jj](torch.matmul(x, layer))
        return x.detach().numpy()
class TorchNNDummy(NPDummy):
    def __init__(self, input_dim, output_dim, hid_dim=[32,32]):
        super(TorchNNDummy, self).__init__(input_dim, output_dim, hid_dim)
        self.act = [torch.tanh] * len(self.dim_h)
        self.act.append(torch.tanh)

    def init_parameters(self):
        self.layers = nn.Sequential(nn.Linear(self.dim_x, self.dim_h[0], bias=False),
                nn.Tanh())
        for ii in range(1, len(self.dim_h)):
            self.layers.add_module("layer{}".format(ii),
                    nn.Linear(self.dim_h[ii-1], self.dim_h[ii], bias=False))
            self.layers.add_module("act{}".format(ii),
                    nn.Tanh())
        self.layers.add_module("endlayer",
                nn.Linear(self.dim_h[-1], self.dim_y, bias=False))
        self.layers.add_module("endact",
                nn.Tanh())

    def forward(self, x):
        x = torch.Tensor(x)
        return self.layers(x).detach().numpy()
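# Sketch (not part of the original gist): PyTorch inference is usually timed
# under torch.no_grad(), which skips autograd graph construction for the
# nn.Linear parameters that otherwise require gradients by default.
# Usage: action = torch_forward_nograd(model, obs)
def torch_forward_nograd(model, x):
    with torch.no_grad():
        return model.layers(torch.Tensor(x)).numpy()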
class TorchConvDummy():
    def __init__(self, input_dim, output_dim, hid_dim=[32,32,32,32]):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dim_h = hid_dim
        self.init_parameters()

    def init_parameters(self):
        self.layers = nn.Sequential(nn.Conv2d(3, self.dim_h[0], 3, stride=2,
                padding=1, bias=False), nn.Tanh())
        for ii in range(1, len(self.dim_h)):
            self.layers.add_module("layer{}".format(ii),
                    nn.Conv2d(self.dim_h[ii-1], self.dim_h[ii], 3, stride=2,
                    padding=1, bias=False))
            self.layers.add_module("act{}".format(ii),
                    nn.Tanh())
        self.layers.add_module("flattener", nn.Flatten())
        # 512 = 32 channels * 4 * 4: four stride-2 convolutions shrink the
        # 64x64 procgen frame to 4x4
        self.layers.add_module("endlayer",
                nn.Linear(512, self.output_dim, bias=False))
        self.layers.add_module("endact",
                nn.Softmax(dim=-1))

    def forward(self, x):
        # HWC frame -> NCHW float tensor
        x = torch.Tensor(x).permute(2,0,1).unsqueeze(0)
        return self.layers(x)

    def get_action(self, x):
        x = self.forward(x)
        action = torch.argmax(x)
        return action.detach().numpy()
class AutogradConvDummy():
    def __init__(self, input_dim, output_dim, hid_dim=[32,32,32,32]):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dim_h = hid_dim
        self.init_parameters()

    def init_parameters(self):
        # kernels are (in_channels, out_channels, 3, 3) to match the
        # dot_axes contraction used in forward below
        self.kernels = []
        self.kernels.append(1e-1 * npr.randn(3, self.dim_h[0], 3, 3))
        for ii in range(1, len(self.dim_h)):
            self.kernels.append(1e-1 * npr.randn(
                    self.dim_h[ii-1], self.dim_h[ii], 3, 3))
        # 100352 = 32 channels * 56 * 56: four "valid" 3x3 convolutions
        # shrink the 64x64 procgen frame to 56x56
        self.dense = 1e-1 * npr.randn(100352, self.output_dim)

    def softmax(self, x):
        x = x - anp.max(x)
        return anp.exp(x) / anp.sum(anp.exp(x), axis=-1)

    def forward(self, x):
        # HWC frame -> NCHW
        x = x.transpose(2,0,1)[anp.newaxis,:,:,:]
        for jj, kernel in enumerate(self.kernels):
            x = anp.tanh(convolve(x, kernel, axes=([2,3], [2,3]),
                    dot_axes=([1], [0]), mode="valid"))
        x = x.ravel()
        x = self.softmax(anp.matmul(x, self.dense))
        return x

    def get_action(self, x):
        x = self.forward(x)
        action = anp.argmax(x, axis=-1)
        return action
class JAXConvDummy():
    def __init__(self, input_dim, output_dim, hid_dim=[32,32,32,32]):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dim_h = hid_dim
        self.init_parameters()

    def init_parameters(self):
        # jax.lax.conv expects OIHW kernels: (out_channels, in_channels, 3, 3)
        self.kernels = []
        self.kernels.append(1e-1 * npr.randn(self.dim_h[0], 3, 3, 3))
        for ii in range(1, len(self.dim_h)):
            self.kernels.append(1e-1 * npr.randn(
                    self.dim_h[ii], self.dim_h[ii-1], 3, 3))
        # 100352 = 32 channels * 56 * 56 after four "valid" 3x3 convolutions
        # on a 64x64 frame
        self.dense = 1e-1 * npr.randn(100352, self.output_dim)

    def softmax(self, x):
        x = x - jnp.max(x)
        return jnp.exp(x) / jnp.sum(jnp.exp(x), axis=-1)

    def forward(self, x):
        # HWC frame -> NCHW float32
        x = jnp.array(x.transpose(2,0,1)[jnp.newaxis,:,:,:], dtype=jnp.float32)
        for jj, kernel in enumerate(self.kernels):
            x = jnp.tanh(j_convolve(x, kernel, window_strides=(1,1), padding="VALID"))
        x = x.ravel()
        x = self.softmax(jnp.matmul(x, self.dense))
        return x

    def get_action(self, x):
        x = self.forward(x)
        action = jnp.argmax(x, axis=-1)
        return action
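# Sketch (not part of the original gist): the same convolution written with
# jax.lax.conv_general_dilated and explicit dimension numbers, which makes
# the NCHW input / OIHW kernel layout assumed by jax.lax.conv visible.
def explicit_conv(x, kernel):
    return jax.lax.conv_general_dilated(x, kernel, window_strides=(1, 1),
            padding="VALID", dimension_numbers=("NCHW", "OIHW", "NCHW"))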
def jax_matmuls(dim_n=256, dim_x=64, num_times=1000, use_jit=True):
    prng_key = jpr.PRNGKey(1)
    x = jpr.normal(prng_key, (dim_n, dim_x))
    w = jpr.normal(prng_key, (dim_x, 1))
    # note: JAX dispatches asynchronously, and the first jitted call also
    # pays compilation cost, so these loop timings mostly measure dispatch
    # plus (for the first call per shape) compilation
    if use_jit:
        for ii in range(num_times):
            _ = jax_jit_matmul(x, w)
    else:
        for ii in range(num_times):
            _ = jax_nojit_matmul(x, w)

@jit
def jax_jit_matmul(x, w):
    return jnp.matmul(x, w)

def jax_nojit_matmul(x, w):
    return jnp.matmul(x, w)

def ag_matmuls(dim_n=256, dim_x=64, num_times=1000):
    x = npr.randn(dim_n, dim_x)
    w = npr.randn(dim_x, 1)
    for ii in range(num_times):
        _ = anp.matmul(x, w)
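# Sketch (not part of the original gist): a fairer JAX timing would warm up
# the jit cache first and call block_until_ready() at the end, since JAX
# returns from dispatch before the computation has actually finished.
def jax_matmuls_blocking(dim_n=256, dim_x=64, num_times=1000):
    prng_key = jpr.PRNGKey(1)
    x = jpr.normal(prng_key, (dim_n, dim_x))
    w = jpr.normal(prng_key, (dim_x, 1))
    jax_jit_matmul(x, w).block_until_ready()  # warm-up call pays compilation
    result = None
    for ii in range(num_times):
        result = jax_jit_matmul(x, w)
    result.block_until_ready()  # wait for the last dispatched matmul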
if __name__ == "__main__":
    max_steps = 1000
    if True:  # set to False to skip this section and only benchmark convolutions
        for dim_n in [1, 32, 512, 4096, 8192]:
            for dim_x in [64, 256, 1024, 4096, 8192]:
                t00 = time.time()
                jax_matmuls(dim_n=dim_n, dim_x=dim_x, num_times=max_steps)
                t11 = time.time()
                ag_matmuls(dim_n=dim_n, dim_x=dim_x, num_times=max_steps)
                t22 = time.time()
                jax_matmuls(dim_n=dim_n, dim_x=dim_x, num_times=max_steps, use_jit=False)
                t33 = time.time()
                print("matmuls ag: {:.2e} s, JAX: {:.2e} s, JAX no jit: {:.2e} s. n = {}, dim_x = {}, steps = {}"
                        .format(t22-t11, t11-t00, t33-t22, dim_n, dim_x, max_steps))
for env_name in ["InvertedPendulumBulletEnv-v0",\ | |
"InvertedDoublePendulumBulletEnv-v0",\ | |
"AntBulletEnv-v0",\ | |
"HumanoidBulletEnv-v0"]: | |
env = gym.make(env_name) | |
input_dim = env.observation_space.sample().shape[0] | |
output_dim = env.action_space.sample().shape[0] | |
for agent_fn in [JAXDummy, TorchDummy, TorchNNDummy, AutogradDummy, NPDummy]: | |
steps = 0 | |
model = agent_fn(input_dim=input_dim, output_dim=output_dim, hid_dim=[32,32]) | |
t0 = time.time() | |
done = True | |
while steps < max_steps: | |
if done: | |
obs = env.reset() | |
done = False | |
obs, reward, done, info = env.step(model.forward(obs)) | |
steps += 1 | |
t1 = time.time() | |
print("env {} completed {} steps in {} s by ".format(env_name, steps, t1-t0), agent_fn) | |
env.close() | |
for env_name in ["procgen:procgen-coinrun-v0"]: | |
env = gym.make(env_name) | |
input_dim = env.observation_space.sample().shape[0] | |
output_dim = 15 | |
for agent_fn in [TorchConvDummy, AutogradConvDummy, JAXConvDummy]: | |
steps = 0 | |
model = agent_fn(input_dim=input_dim, output_dim=output_dim) | |
t0 = time.time() | |
done = True | |
while steps < max_steps: | |
if done: | |
obs = env.reset() | |
done = False | |
obs, reward, done, info = env.step(model.get_action(obs)) | |
steps += 1 | |
t1 = time.time() | |
print("env {} completed {} steps in {} s by ".format(env_name, steps, t1-t0), agent_fn) | |
env.close() |
absl-py==0.9.0
atari-py==0.2.6
autograd==1.3
cffi==1.14.0
cloudpickle==1.3.0
cycler==0.10.0
filelock==3.0.12
future==0.18.2
glcontext==2.2.0
glfw==1.12.0
gym==0.17.2
gym3==0.3.3
imageio==2.9.0
imageio-ffmpeg==0.3.0
jax==0.1.73
jaxlib @ https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.51-cp36-none-manylinux2010_x86_64.whl
kiwisolver==1.2.0
matplotlib==3.3.0
moderngl==5.6.1
numpy==1.19.1
opencv-python==4.3.0.36
opt-einsum==3.3.0
Pillow==7.2.0
pkg-resources==0.0.0
procgen==0.10.4
pybullet==2.8.4
pycparser==2.20
pyglet==1.5.0
pyparsing==2.4.7
python-dateutil==2.8.1
scipy==1.5.1
six==1.15.0
torch==1.5.1