Created April 20, 2020 02:59
tamnguyenvan/e01a863e6d9f44c90b5e390693cd8ca9
TensorFlow slice assignment.
import tensorflow as tf

# tf.enable_eager_execution()  # uncomment if TF version < 2.0


def assign(x, slices, values):
    """Return a copy of `x` whose region `slices` is replaced by `values`.

    Args:
        x: A `Tensor` as input.
        slices: A `list`, `numpy array` or `Tensor` of `(begin, end)` pairs,
            one per dimension of `x` (`end` exclusive), indicating the region
            of the tensor that will be assigned.
        values: A `Tensor` holding the assigned values. Its shape must be
            `[end - begin for (begin, end) in slices]`.

    Returns:
        A new tensor equal to `x` except in the specified region, which
        holds `values`. `x` itself is not modified.
    """
    input_shape = tf.shape(x)
    # Shape of the region being assigned.
    shape = [(end - beg) for beg, end in slices]
    # Zeros to add before and after the region along each dimension.
    padding = [(beg, dim - end) for dim, (beg, end) in zip(input_shape, slices)]
    # Mask that is 1 inside the region and 0 elsewhere, in the dtype of `x`.
    mask = tf.pad(tf.ones(shape, dtype=x.dtype), padding)
    # `values` embedded in a zero tensor with the same shape as `x`.
    padded_values = tf.pad(tf.cast(values, x.dtype), padding)
    # `1 - mask` (not `1. - mask`) so that integer tensors also work.
    return mask * padded_values + (1 - mask) * x


x = tf.random.uniform([3, 3], 0, 1)
slices = [(0, 2), (0, 2)]
values = tf.constant([[0, 0], [0, 0]])
x = assign(x, slices, values)
print(x)
As you know, TensorFlow still doesn't provide a convenient way to assign values to a slice of a tensor (the equivalent of NumPy's `x[0:2, 0:2] = values`). I don't know why; it's a crucial feature. To work around it, I wrote this script. It's not perfect, but it's still helpful for me. It's heavily inspired by an idea I found in a TensorFlow issue.