Skip to content

Instantly share code, notes, and snippets.

@Rocketknight1
Created January 27, 2021 19:53
Show Gist options
  • Save Rocketknight1/8b3ec932995be18220e869c47a8d40e7 to your computer and use it in GitHub Desktop.
Save Rocketknight1/8b3ec932995be18220e869c47a8d40e7 to your computer and use it in GitHub Desktop.
import tensorflow as tf
from tensorflow_addons.layers import embeddingbag
import numpy as np
from time import perf_counter
@tf.function
def composite_run(indices, values, weights):
    """Run the tensorflow_addons ``embeddingbag`` op and differentiate it.

    Args:
        indices: Index tensor forwarded unchanged to ``embeddingbag``.
        values: Embedding table forwarded to ``embeddingbag``.
        weights: Per-index weights forwarded to ``embeddingbag``.

    Returns:
        ``(out, grads)``, where ``out`` is the op's result. ``grads`` is the list
        ``[d_values, d_weights]``. Because ``out`` is not a scalar,
        ``GradientTape.gradient`` differentiates the sum of its elements.
    """
    sources = [values, weights]
    with tf.GradientTape() as tape:
        # main() passes plain tensors (tf.convert_to_tensor), not tf.Variables,
        # so the tape only records ops on them once they are watched explicitly.
        tape.watch(sources)
        out = embeddingbag(indices, values, weights)
    return out, tape.gradient(out, sources)
@tf.function
def sparse_dense_matmul_run(indices, values, weights):
    """Compute a weighted embedding-bag sum via sparse-dense matmul, with gradients.

    Builds a SparseTensor ``A`` of dense shape ``(batch, num_embeddings)``.
    Row ``b`` of ``A`` holds ``weights[b, k]`` at column ``indices[b, k]``.
    The function returns ``A @ values`` together with the gradients of that
    result.

    Args:
        indices: int64 tensor of shape ``(batch, bag_size)``. It must be int64
            because it is stacked with int64 row ids; it must be rank 2 because
            the row ids repeat each batch row ``bag_size`` times. Each entry is
            a row index into ``values``.
        values: 2-D embedding table ``(num_embeddings, embedding_dim)``. It is
            the dense operand of the matmul and must share the dtype of
            ``weights``.
        weights: One weight per entry of ``indices``. It is flattened in the
            same row-major order as ``indices``.

    Returns:
        ``(out, grads)``. ``out`` has shape ``(batch, embedding_dim)``.
        ``grads`` is ``[d_values, d_weights]``, taken with respect to the sum of
        ``out``, because ``GradientTape.gradient`` sums a non-scalar target.
    """
    with tf.GradientTape() as tape:
        # Inputs are plain tensors, so they must be watched explicitly.
        tape.watch(values)
        tape.watch(weights)
        batch_size = tf.shape(indices)[0]
        bag_size = tf.shape(indices)[-1]
        # Row id of every (bag, slot) entry in row-major order: [0]*K + [1]*K + ...
        row_ids = tf.repeat(tf.range(batch_size, dtype=tf.int64), bag_size)
        col_ids = tf.reshape(indices, (-1,))
        # NOTE: rows are ordered but columns within a row are not sorted. The
        # TF docs recommend lexicographic order (tf.sparse.reorder) for the best
        # matmul performance. The order is left as-is so the benchmark does not
        # pay for a reorder.
        sp_matrix = tf.SparseTensor(
            indices=tf.stack([row_ids, col_ids], axis=1),
            values=tf.reshape(weights, (-1,)),
            dense_shape=(batch_size, tf.shape(values)[0]),
        )
        out = tf.sparse.sparse_dense_matmul(sp_matrix, values)
    grads = tape.gradient(out, [values, weights])
    return out, grads
def _block_until_ready(result):
    """Block until every tensor in ``result`` has been computed on its device.

    TensorFlow may return from a call while device (e.g. GPU) kernels are still
    running, so a timer stopped right after the call can under-count.

    Every non-``None`` component of ``result`` is used; composite tensors such
    as ``IndexedSlices`` are expanded into their parts. Each component is
    reduced to a scalar, and the total is copied to the host with ``.numpy()``.
    That copy cannot complete until the producing ops have run.

    Returns:
        The host-side float total, or ``None`` if ``result`` holds no tensors.
    """
    parts = [
        t for t in tf.nest.flatten(result, expand_composites=True) if t is not None
    ]
    if not parts:
        return None
    total = tf.add_n([tf.reduce_sum(tf.cast(t, tf.float32)) for t in parts])
    return total.numpy()


def _time_runs(fn, args, n_runs):
    """Return the wall-clock seconds taken by ``n_runs`` sequential calls of ``fn(*args)``.

    Only the final result is synchronised before the clock stops. This assumes
    (as with TF's default single compute stream per GPU) that device work
    completes in dispatch order, so earlier iterations are finished too. The
    small synchronising reduction is included in the measured time.
    """
    start = perf_counter()
    result = None
    for _ in range(n_runs):
        result = fn(*args)
    if result is not None:
        _block_until_ready(result)
    return perf_counter() - start


def main(n_runs=50):
    """Benchmark ``composite_run`` against ``sparse_dense_matmul_run``.

    Builds a random problem: 16384 bags of 128 int64 indices into a
    262144 x 256 float32 table, with one random float32 weight per index.
    Each implementation is called once to trace it, and that call is waited on
    so it does not leak into the timings. Each implementation is then timed
    over ``n_runs`` calls and the elapsed wall-clock time is printed.

    Args:
        n_runs: Number of timed calls per implementation. The default of 50 is
            the value the script hard-coded.
    """
    num_embeddings = 512 * 512
    embedding_dim = 256
    batch_size, bag_size = 64 * 256, 128

    # np.random.randint's `high` is exclusive. Passing num_embeddings itself
    # lets every row of the table, including the last one, be sampled.
    indices = np.random.randint(
        low=0, high=num_embeddings, size=(batch_size, bag_size)
    ).astype(np.int64)
    values = np.random.rand(num_embeddings, embedding_dim).astype(np.float32)
    weights = np.random.rand(*indices.shape).astype(np.float32)

    args = tuple(tf.convert_to_tensor(a) for a in (indices, values, weights))

    # The first call of each tf.function traces it. Wait for that work to
    # finish so it is not counted in the timed loops below.
    print("Testing composite run...")
    _block_until_ready(composite_run(*args))
    print("Testing sparse run...")
    _block_until_ready(sparse_dense_matmul_run(*args))

    print('Running with composite op...')
    elapsed = _time_runs(composite_run, args, n_runs)
    print(f"{n_runs} runs took {elapsed} seconds")

    print('Running with sparse-dense matmul...')
    elapsed = _time_runs(sparse_dense_matmul_run, args, n_runs)
    print(f"{n_runs} runs took {elapsed} seconds")


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment