Created
October 5, 2022 03:15
-
-
Save jaymody/a14f04813243e84a9360c3095d9da474 to your computer and use it in GitHub Desktop.
Comparing jax code runtimes with jax.array vs np.array
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax | |
import jax.numpy as jnp | |
def forward_fn(params, X):
    """Apply an MLP to a batch X: ReLU hidden layers, then a linear output.

    params is a list of (W, b) pairs; every pair except the last is a
    hidden layer, and the last pair produces the unnormalized logits.
    """
    *hidden_layers, (out_W, out_b) = params
    for W, b in hidden_layers:
        X = jax.nn.relu(X @ W + b)
    # final layer has no activation — callers apply softmax/log-softmax
    return X @ out_W + out_b
def initialize_params(key, input_dim, hidden_dims, output_dim):
    """Build MLP parameters as a list of (W, b) pairs.

    Weights are drawn from a standard normal (one PRNG subkey per layer);
    biases start at zero. Layer sizes run
    input_dim -> hidden_dims... -> output_dim.
    """
    dims = [input_dim, *hidden_dims, output_dim]
    subkeys = jax.random.split(key, len(dims) - 1)
    layers = []
    for k, (n_in, n_out) in zip(subkeys, zip(dims, dims[1:])):
        W = jax.random.normal(k, (n_in, n_out))
        b = jnp.zeros((n_out,))
        layers.append((W, b))
    return layers
def loss_fn(params, X, y):
    """Mean cross-entropy loss of the network's predictions on (X, y).

    y holds integer class labels; the loss is averaged over the batch.
    """
    # forward pass produces unnormalized logits
    logits = forward_fn(params, X)
    n_examples = logits.shape[0]
    n_classes = logits.shape[-1]
    # cross entropy: -sum(one_hot(y) * log_softmax(logits)) / batch
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    targets = jax.nn.one_hot(y, n_classes)
    return jnp.sum(targets * -log_probs) / n_examples
@jax.jit
def update(params, X, y, lr):
    """Run one SGD step on a minibatch.

    Returns (loss, new_params); params are not mutated in place.
    lr is traced by jit, so different learning rates reuse the same
    compiled function.
    """
    # compute loss and gradient
    loss, grad = jax.value_and_grad(loss_fn)(params, X, y)
    # good ole vanilla stochastic gradient descent.
    # NOTE: jax.tree_map was deprecated (and later removed) in modern JAX;
    # jax.tree_util.tree_map is the stable spelling of the same API.
    params = jax.tree_util.tree_map(lambda w, g: w - lr * g, params, grad)
    return loss, params
def train(params, X, y, batch_size, lr):
    """Run one epoch of minibatch SGD over (X, y).

    NOTE: the updated parameters are only rebound locally and the function
    returns None — this exists purely to time the training loop.
    """
    n = len(X)
    start = 0
    while start < n:
        stop = start + batch_size
        loss, params = update(params, X[start:stop], y[start:stop], lr)
        start = stop
        # print(f"loss at step {start} = {loss}")
def main(conversion):
    """Time one list->array conversion strategy, jit warmup, and training.

    conversion names an entry in the conversions table below; timings are
    printed to stdout for comparison across strategies.
    """
    import random
    import time

    import numpy as np

    # create dummy data to simulate mnist (60000 examples of 784 features)
    n_samples, n_features = 60000, 784
    dummy_X = [[random.random() for _ in range(n_features)] for _ in range(n_samples)]
    dummy_y = [random.randint(0, 9) for _ in range(n_samples)]

    # candidate python-list -> array conversion strategies
    conversions = {
        "np.array": np.array,
        "np.asarray": np.asarray,
        "jnp.array": jnp.array,
        "jnp.array + np.array": lambda x: jnp.array(np.array(x)),
        "jnp.array + np.asarray": lambda x: jnp.array(np.asarray(x)),
        "jnp.asarray": jnp.asarray,
        "jnp.asarray + np.array": lambda x: jnp.asarray(np.array(x)),
        "jnp.asarray + np.asarray": lambda x: jnp.asarray(np.asarray(x)),
    }
    convert = conversions[conversion]

    # initialize params
    params = initialize_params(jax.random.PRNGKey(123), 784, [128, 64], 10)

    # time the list -> array conversion
    t0 = time.time()
    X = convert(dummy_X)
    y = convert(dummy_y)
    convert_time = time.time() - t0

    # run update once so its jit compilation cost is measured separately
    t0 = time.time()
    update(params, X[:64], y[:64], 1e-3)
    jit_time = time.time() - t0

    # time one full training epoch
    t0 = time.time()
    train(params, X, y, 64, 1e-3)
    train_time = time.time() - t0

    print("conversion =", conversion)
    print("convert_times =", convert_time)
    print("jit_time =", jit_time)
    print("train_times =", train_time)
if __name__ == "__main__":
    # usage: python <script>.py <conversion-name>
    # where <conversion-name> is a key of the conversions dict in main()
    import sys
    main(sys.argv[1])
## no jit | |
# conversion = np.array | |
# convert_times = 1.390981674194336 | |
# train_times = 12.141575813293457 | |
# conversion = np.asarray | |
# convert_times = 1.2834157943725586 | |
# train_times = 11.19111704826355 | |
# conversion = jnp.array | |
# convert_times = 95.75779509544373 | |
# train_times = 21.965157985687256 | |
# conversion = jnp.array + np.array | |
# convert_times = 1.3955588340759277 | |
# train_times = 22.44120502471924 | |
# conversion = jnp.array + np.asarray | |
# convert_times = 1.3871350288391113 | |
# train_times = 22.535207986831665 | |
# conversion = jnp.asarray | |
# convert_times = 89.70011687278748 | |
# train_times = 21.396696090698242 | |
# conversion = jnp.asarray + np.array | |
# convert_times = 1.3765108585357666 | |
# train_times = 21.16973900794983 | |
# conversion = jnp.asarray + np.asarray | |
# convert_times = 1.3343868255615234 | |
# train_times = 21.99706506729126 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment