Skip to content

Instantly share code, notes, and snippets.

@gautamsreekumar
Last active March 11, 2019 08:18
Show Gist options
  • Save gautamsreekumar/81bf1ad1037ccfc5dfa8da1be28f2216 to your computer and use it in GitHub Desktop.
Save gautamsreekumar/81bf1ad1037ccfc5dfa8da1be28f2216 to your computer and use it in GitHub Desktop.
def initi(var_shape):
    """Build a constant initializer filled with random complex values.

    Real and imaginary parts are drawn independently with ``nprand.rand``
    (uniform on [0, 1) assuming ``nprand`` is ``numpy.random`` -- TODO confirm
    against the file-level import) and combined as ``real + imag * 1j``.

    Args:
        var_shape: Sequence of ints giving the variable shape, e.g. the
            4-element ``[h, w, in_c, out_c]`` filter shapes used for w2..w4.
            Shapes of any rank are accepted; 4-element shapes produce exactly
            the same draws as the original 4-D-only implementation.

    Returns:
        A ``tf.constant_initializer`` wrapping an array of shape ``var_shape``
        holding the complex values.
    """
    shape = tuple(var_shape)
    # Unpacking instead of indexing [0]..[3] removes the hard-coded rank of 4.
    real_part = nprand.rand(*shape)
    imag_part = nprand.rand(*shape)
    return tf.constant_initializer(real_part + imag_part * 1.0j)
def lrelu(tensor_in):  # NOTE: despite the name, this is NOT a leaky ReLU
    """Pass a complex tensor through unchanged (placeholder activation).

    The tensor is split into its real and imaginary parts with ``tf.real`` /
    ``tf.imag`` and immediately reassembled with ``tf.complex``. No
    non-linearity is applied, so for a complex input the values come out
    exactly as they went in.
    """
    return tf.complex(tf.real(tensor_in), tf.imag(tensor_in))
learning_rate = 0.01


def _complex_conv2d(x, w, strides, padding):
    """Convolve complex ``x`` with complex filter ``w`` via real convolutions.

    ``tf.nn.conv2d`` only accepts real dtypes (half, bfloat16, float32,
    float64), so feeding it complex128 tensors fails at graph construction.
    Convolution is bilinear, so the complex product
    (a + ib) * (c + id) = (ac - bd) + i(ad + bc) is expanded into four real
    convolutions whose results are recombined into a complex tensor.

    Args:
        x: Complex input tensor (NHWC, the ``tf.nn.conv2d`` default layout).
        w: Complex filter tensor of shape ``[h, w, in_c, out_c]``.
        strides: Forwarded unchanged to ``tf.nn.conv2d``.
        padding: Forwarded unchanged to ``tf.nn.conv2d``.

    Returns:
        Complex tensor with the output shape ``tf.nn.conv2d`` would give for
        real inputs of the same shapes.
    """
    x_re, x_im = tf.real(x), tf.imag(x)
    w_re, w_im = tf.real(w), tf.imag(w)

    def conv(signal, kernel):
        return tf.nn.conv2d(input=signal, filter=kernel, strides=strides, padding=padding)

    real_out = conv(x_re, w_re) - conv(x_im, w_im)
    imag_out = conv(x_re, w_im) + conv(x_im, w_re)
    return tf.complex(real_out, imag_out)


# training networks
# Input and target placeholders for training.
# NOTE(review): tr_img_size and channels are defined outside this snippet.
inputs_train_64 = tf.placeholder(
    tf.float64, (None, tr_img_size, tr_img_size, channels), name="inputs_train_64")
# With the default (bilinear) method, resize_images returns float32, which is
# why w1 below is float32 while the deeper filters are complex128.
inputs_train = tf.image.resize_images(inputs_train_64, size=[256, 256])
targets_train = tf.placeholder(tf.complex128, (None, 256, 256, 2), name="targets_train")

# Filter shapes are [h x w x in_c x out_c].
w1_shape = [2, 2, channels, 64]
w2_shape = [2, 2, 64, 128]
w3_shape = [2, 2, 128, 256]
w4_shape = [2, 2, 256, 512]

w1 = tf.get_variable(name='w1', shape=w1_shape, dtype=tf.float32,
                     initializer=tf.truncated_normal_initializer())
w2 = tf.get_variable(name='w2', shape=w2_shape, dtype=tf.complex128,
                     initializer=initi(w2_shape))
w3 = tf.get_variable(name='w3', shape=w3_shape, dtype=tf.complex128,
                     initializer=initi(w3_shape))
w4 = tf.get_variable(name='w4', shape=w4_shape, dtype=tf.complex128,
                     initializer=initi(w4_shape))

# Each 2x2, stride-2, VALID convolution halves the spatial size:
# 256 -> 128 (conv1) -> 64 (conv2) -> 32 (conv3) -> 16 (encoded).
conv1 = tf.nn.conv2d(input=inputs_train, filter=w1, strides=[1, 2, 2, 1], padding='VALID')
# Promote the real first layer to complex128 with an all-zero imaginary part.
conv1 = tf.cast(conv1, tf.float64)
conv1 = tf.complex(conv1, tf.zeros_like(conv1))
# The remaining layers are complex; tf.nn.conv2d cannot take complex128 directly.
conv2 = _complex_conv2d(lrelu(conv1), w2, strides=[1, 2, 2, 1], padding='VALID')
conv3 = _complex_conv2d(lrelu(conv2), w3, strides=[1, 2, 2, 1], padding='VALID')
encoded = _complex_conv2d(lrelu(conv3), w4, strides=[1, 2, 2, 1], padding='VALID')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment