Skip to content

Instantly share code, notes, and snippets.

@bodokaiser
Last active September 24, 2017 09:29
Show Gist options
  • Save bodokaiser/a7f8d655c3f503ca8db732f9e0f8086b to your computer and use it in GitHub Desktop.
Save bodokaiser/a7f8d655c3f503ca8db732f9e0f8086b to your computer and use it in GitHub Desktop.
How to construct a 3D grid of indices in TensorFlow (equivalent functions exist in NumPy).
import tensorflow as tf
# Volume shape: H x W x D.
#
# Cartesian notation (x, y) -- x runs along the width, y along the height:
# (0,0) (1,0) (2,0) ... (W-1,0)
# (0,1) (1,1) (2,1) ... (W-1,1)
#
# Index notation (i, j) -- i is the row, j is the column:
# (0,0) (0,1) (0,2) ... (0,W-1)
# (1,0) (1,1) (1,2) ... (1,W-1)
#
# Point in index notation (i, j, k):
point = [60, 100, 20]
# Shape of the grid (patch) surrounding the point.
shape = [32, 32, 32]


def _axis_range(center, size):
    """Return ``size`` consecutive int32 indices around ``center``.

    The range starts at ``center - size // 2`` and always has exactly
    ``size`` elements, also for odd sizes (the previous
    ``[center - size//2, center + size//2)`` form lost one element when
    ``size`` was odd). Values are NOT clipped to the volume bounds, so a
    point near the border produces negative or too-large indices.
    """
    start = center - size // 2
    return tf.range(start, start + size, dtype=tf.int32)


# i = 44, ..., 75
rows = _axis_range(point[0], shape[0])
# j = 84, ..., 115
cols = _axis_range(point[1], shape[1])
# k = 4, ..., 35 (tf.range excludes its limit)
slices = _axis_range(point[2], shape[2])

# With indexing='ij' every output has shape [len(rows), len(cols), len(slices)]
# and output[a, b, c] is taken from rows[a], cols[b] and slices[c]
# respectively. Passing (rows, cols, slices) in that order makes the names
# i, j, k agree with the index notation above:
#   i[a, b, c] == rows[a]    -> 44 ... 75 along axis 0
#   j[a, b, c] == cols[b]    -> 84 ... 115 along axis 1
#   k[a, b, c] == slices[c]  -> 4 ... 35 along axis 2
i, j, k = tf.meshgrid(rows, cols, slices, indexing='ij')

# Flattening is row-major (k varies fastest). Stacking along the LAST axis
# gives indices of shape [32^3, 3], one (i, j, k) triple per row:
# [[44,84,4],[44,84,5],...,[44,84,35],[44,85,4],...,[75,115,35]]
# NOTE: tf.stack defaults to axis=0, which would give shape [3, 32^3].
# NOTE(review): presumably meant for tf.gather_nd on an [H, W, D] volume;
# reshaping the gathered values to `shape` then restores the patch in
# (i, j, k) axis order -- confirm against the intended use.
indices = tf.stack([
    tf.reshape(i, [-1]),
    tf.reshape(j, [-1]),
    tf.reshape(k, [-1]),
], axis=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment