Created January 27, 2021 19:53
Save Rocketknight1/8b3ec932995be18220e869c47a8d40e7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow_addons.layers import embeddingbag | |
import numpy as np | |
from time import perf_counter | |
@tf.function
def composite_run(indices, values, weights):
    """Run the ``embeddingbag`` op forward and backward.

    Args:
        indices: Index tensor, passed unchanged to ``embeddingbag``.
        values: Embedding table, passed to ``embeddingbag``; watched so the
            tape produces a gradient for it.
        weights: Per-index weights, passed to ``embeddingbag``; watched so
            the tape produces a gradient for it.

    Returns:
        ``(out, grads)`` where ``out`` is the ``embeddingbag`` result and
        ``grads`` is ``[gradient w.r.t. values, gradient w.r.t. weights]``
        of ``out`` (a non-scalar target is differentiated as its sum).
    """
    # Plain tensors (not Variables) are only tracked once explicitly watched;
    # the list form watches values first, then weights.
    with tf.GradientTape() as tape:
        tape.watch([values, weights])
        bag_output = embeddingbag(indices, values, weights)
    return bag_output, tape.gradient(bag_output, [values, weights])
@tf.function
def sparse_dense_matmul_run(indices, values, weights):
    """Emulate a weighted embedding bag with a sparse-dense matmul.

    For each bag ``b`` the function builds a sparse row of a
    ``(batch_size, num_embeddings)`` matrix. The row holds ``weights[b, j]`` at
    column ``indices[b, j]``. Multiplying this sparse matrix by the dense
    embedding table ``values`` gives each bag's weighted sum of the selected
    embedding rows. The forward and backward passes are recorded.

    NOTE(review): repeated indices within a bag become duplicate sparse
    entries. ``sparse_dense_matmul`` does not validate indices, and they appear
    to be accumulated (summed), which matches embedding-bag semantics.
    Confirm this if exact equivalence matters.

    Args:
        indices: Rank-2 integer tensor ``(batch_size, bag_size)`` of row ids
            into ``values``. Both int32 and int64 are accepted.
        values: Dense embedding table ``(num_embeddings, embedding_dim)``.
        weights: Tensor with the same shape as ``indices`` giving the weight
            of every looked-up row.

    Returns:
        ``(out, grads)``. ``out`` has shape ``(batch_size, embedding_dim)``.
        ``grads`` is ``[gradient w.r.t. values, gradient w.r.t. weights]``.
    """
    # SparseTensor requires int64 indices and dense_shape, so cast the index
    # input and request int64 shapes. This avoids concat dtype mismatches
    # and implicit int32 -> int64 conversions.
    indices = tf.cast(indices, tf.int64)
    batch_size = tf.shape(indices, out_type=tf.int64)[0]
    bag_size = tf.shape(indices, out_type=tf.int64)[1]
    num_embeddings = tf.shape(values, out_type=tf.int64)[0]

    # Coordinates of every (bag, slot) entry in row-major order:
    # row ids are 0,0,...,0,1,1,...; column ids are the flattened indices.
    # These are integer-only, so they need not be recorded on the tape.
    row_ids = tf.repeat(tf.range(batch_size), bag_size)
    col_ids = tf.reshape(indices, (-1,))
    sp_indices = tf.stack([row_ids, col_ids], axis=1)

    with tf.GradientTape() as tape:
        # Plain tensors must be watched explicitly to receive gradients.
        tape.watch(values)
        tape.watch(weights)
        # The weights become the non-zero values of the sparse matrix.
        sp_values = tf.reshape(weights, (-1,))
        sp_matrix = tf.SparseTensor(
            sp_indices,
            sp_values,
            dense_shape=tf.stack([batch_size, num_embeddings]),
        )
        out = tf.sparse.sparse_dense_matmul(sp_matrix, values)
    grads = tape.gradient(out, [values, weights])
    return out, grads
def _block_until_ready(result):
    """Wait until every tensor in ``result`` has actually been computed."""
    # Reducing each component to a scalar and copying it to the host cannot
    # finish before the kernels producing it have run. This works even when
    # device work was dispatched asynchronously. expand_composites=True also
    # reaches the component tensors of composite results such as IndexedSlices.
    for tensor in tf.nest.flatten(result, expand_composites=True):
        tf.reduce_sum(tensor).numpy()


def _max_abs_diff(a, b):
    """Return the largest absolute element-wise difference of ``a`` and ``b``.

    Either input may be a composite gradient such as ``IndexedSlices``. It is
    densified via ``tf.convert_to_tensor``. Returns NaN if either is None,
    i.e. when no gradient was produced.
    """
    if a is None or b is None:
        return float("nan")
    a = tf.convert_to_tensor(a)
    b = tf.convert_to_tensor(b)
    return float(tf.reduce_max(tf.abs(a - b)))


def _time_runs(fn, indices, values, weights, num_runs):
    """Return the wall-clock seconds taken by ``num_runs`` calls of ``fn``."""
    start = perf_counter()
    result = None
    for _ in range(num_runs):
        result = fn(indices, values, weights)
    # Stop the clock only once the last call's work has completed. Presumably
    # later calls are queued behind earlier ones, so this covers all runs.
    if result is not None:
        _block_until_ready(result)
    return perf_counter() - start


def main(
    num_embeddings=512 * 512,
    embedding_dim=256,
    batch_size=64 * 256,
    bag_size=128,
    num_runs=50,
):
    """Benchmark ``embeddingbag`` against an equivalent sparse-dense matmul.

    Builds random inputs and runs each implementation once. The first call
    also traces each ``tf.function``. It then prints how far the outputs and
    gradients of the two implementations differ, and finally times
    ``num_runs`` forward+backward passes of each.

    Args:
        num_embeddings: Number of rows in the embedding table.
        embedding_dim: Width of each embedding row.
        batch_size: Number of bags.
        bag_size: Number of indices (and weights) per bag.
        num_runs: Timed iterations per implementation.
    """
    # randint's upper bound is exclusive, so high=num_embeddings samples every
    # valid row id in [0, num_embeddings).
    indices = np.random.randint(
        low=0, high=num_embeddings, size=(batch_size, bag_size)
    ).astype(np.int64)
    values = np.random.rand(num_embeddings, embedding_dim).astype(np.float32)
    weights = np.random.rand(*indices.shape).astype(np.float32)

    indices = tf.convert_to_tensor(indices)
    values = tf.convert_to_tensor(values)
    weights = tf.convert_to_tensor(weights)

    # Warm-up: trace both tf.functions outside the timed loops.
    print("Testing composite run...")
    composite_out, composite_grads = composite_run(indices, values, weights)
    print("Testing sparse run...")
    sparse_out, sparse_grads = sparse_dense_matmul_run(indices, values, weights)

    # The two implementations should agree up to float rounding.
    print(f"Max abs output difference: {_max_abs_diff(composite_out, sparse_out)}")
    for name, composite_grad, sparse_grad in zip(
        ("values", "weights"), composite_grads, sparse_grads
    ):
        diff = _max_abs_diff(composite_grad, sparse_grad)
        print(f"Max abs gradient difference w.r.t. {name}: {diff}")
    # Release the large warm-up results before timing.
    del composite_out, composite_grads, sparse_out, sparse_grads

    for label, fn in (
        ("composite op", composite_run),
        ("sparse-dense matmul", sparse_dense_matmul_run),
    ):
        print(f"Running with {label}...")
        elapsed = _time_runs(fn, indices, values, weights, num_runs)
        print(f"{num_runs} runs took {elapsed} seconds")


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.