Created
January 16, 2021 17:09
-
-
Save ntakouris/30c12d52036fc00e7609259bab6051c7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@staticmethod
def gh_jacobian(y, x):
    """Assemble the Jacobian of ``y`` w.r.t. ``x`` one output column at a time.

    For each index ``m`` along the last axis of ``y`` the slice
    ``y[:, m:m+1]`` is differentiated against ``x`` with ``tf.gradients``;
    ``tf.map_fn`` stacks the results along a new leading axis. All size-1
    axes are then squeezed away and the first two axes are swapped.

    Args:
        y: tensor sliced as ``y[:, m:m+1]``. Per the original author's note
            it must be ``[batch_size, n]`` with ``n > 1``; use
            ``gh_jacobian_vec`` when ``n == 1``.
        x: tensor the gradients are taken against. The original note calls
            it "a vector"; the 3-axis transpose below only works if each
            gradient is rank 2 after squeezing, so presumably ``x`` is
            ``[batch_size, d]`` -- TODO confirm against the caller.

    Returns:
        The stacked per-column gradients after ``tf.squeeze`` and a
        ``[1, 0, 2]`` axis permutation.

    NOTE(review): ``tf.squeeze`` without ``axis`` removes *every* size-1
    dimension, so an input with a size-1 batch or feature axis would make
    the rank-3 transpose fail -- confirm callers never hit that case.
    """
    def grad_of_column(m):
        # tf.gradients returns one gradient per entry of xs; only x is given.
        return tf.gradients(y[:, m:m + 1], x)[0]

    n_outputs = tf.shape(y)[-1]
    per_column = tf.map_fn(grad_of_column, tf.range(n_outputs), tf.float32)
    compact = tf.squeeze(per_column)
    return tf.transpose(compact, perm=[1, 0, 2])
@staticmethod
def gh_jacobian_vec(y, x):
    """Gradient of a single-column ``y`` with respect to ``x``.

    Uses the same per-column ``tf.gradients`` / ``tf.map_fn`` stacking as
    ``gh_jacobian`` but returns the squeezed result directly, without any
    axis permutation.

    Args:
        y: tensor sliced as ``y[:, m:m+1]``. Per the original author's note
            it must be ``[batch_size, 1]``.
        x: tensor the gradient is taken against (the original note calls it
            "a vector"; its exact shape is not visible here -- TODO confirm).

    Returns:
        The stacked gradients with every size-1 axis removed by
        ``tf.squeeze``.
    """
    def grad_of_column(m):
        # Only one tensor is passed as xs, so take the single result.
        return tf.gradients(y[:, m:m + 1], x)[0]

    column_ids = tf.range(tf.shape(y)[-1])
    stacked = tf.map_fn(grad_of_column, column_ids, tf.float32)
    return tf.squeeze(stacked)
def build_inpaint_graph(self, log_generator_loss=False):
    """Builds the context and prior loss objective.

    Inside ``self.graph`` this creates the ``mask`` and ``images``
    placeholders and the ops that solve for a latent update ``self.dz``:

        A      = (J_g^T M^T) @ (M J_g)
        lambda = 1 / (J_d^T A^-1 J_d)
        b      = (J_g^T M^T)(y - M G(z)) - lambda * J_d
        dz     = A^-1 b

    (formulas as given in the original author's comments).

    Reads ``self.graph``, ``self.image_shape``, ``self.batch_size``,
    ``self.go``, ``self.gl`` and ``self.gi``. Stores ``self.masks``,
    ``self.images``, ``self.a_left``, ``self.a_right``, ``self.a``,
    ``self.a_inv``, ``self._lambd``, ``self.b_left``, ``self.b_right``,
    ``self.b`` and ``self.dz``.

    Args:
        log_generator_loss: accepted but not read anywhere in this method.
    """
    # The image size (64 x 64 x 3) and the latent size (100) are fixed by
    # the reshapes below. The batch dimension always comes from
    # self.batch_size; it was previously hard-coded as 64 in three of the
    # reshapes, which only worked when batch_size happened to be 64.
    n_pixels = 64 * 64 * 3
    z_dim = 100
    bs = self.batch_size

    with self.graph.as_default():
        self.masks = tf.placeholder(tf.float32,
                                    [None] + self.image_shape,
                                    name='mask')
        self.images = tf.placeholder(tf.float32,
                                     [None] + self.image_shape,
                                     name='images')

        ##### setup #####
        m = self.masks
        y = tf.math.multiply(self.images, self.masks)

        # from raymond's paper we know that: gl = 1 - d(g(z))
        do = tf.constant([1.0]) - self.gl
        go = self.go
        m_go = tf.math.multiply(go, m)

        # Jacobian of the masked generator output w.r.t. self.gi, reshaped
        # to [bs, 64, 64, 3, z_dim].
        j_g_m = tf.reshape(
            self.gh_jacobian(tf.reshape(m_go, [bs, n_pixels]), self.gi),
            [bs, 64, 64, 3, z_dim])
        # Gradient of `do` w.r.t. self.gi; the original author's notes
        # treat it as [bs, 100] -- TODO confirm.
        j_d = self.gh_jacobian_vec(do, self.gi)

        # (J_g^t * M^t) @ (M * j_g)
        a_right = j_g_m
        # NOTE(review): this swaps the two 64-sized image axes on the left
        # factor only, so after flattening the rows of a_left and a_right
        # are paired up in different pixel orders and `a` is not the plain
        # Gram matrix of a_right. The same a_left is reused for b_left.
        # Confirm this is intended rather than a stand-in for the matrix
        # transpose that the perm=[0, 2, 1] below already performs.
        a_left = tf.transpose(j_g_m, perm=[0, 2, 1, 3, 4])

        self.a_left = a_left
        self.a_right = a_right

        # [bs, n_pixels, z_dim] -> [bs, z_dim, n_pixels]
        a_left = tf.transpose(
            tf.reshape(a_left, [bs, n_pixels, z_dim]), perm=[0, 2, 1])
        a_right = tf.reshape(a_right, [bs, n_pixels, z_dim])

        # [bs, z_dim, n_pixels] @ [bs, n_pixels, z_dim] -> [bs, z_dim, z_dim]
        a = a_left @ a_right
        self.a = a

        a_inv = tf.linalg.inv(a)
        self.a_inv = a_inv

        # Insert a size-1 axis at position 1 so each sample's j_d acts as a
        # row vector in the batched matmul below.
        j_d_t = tf.expand_dims(j_d, axis=1)

        # 1 / J_d^t * A^-1 * J_d
        x = j_d_t @ a_inv
        _lambd = 1 / tf.linalg.matvec(x, j_d)
        self._lambd = _lambd

        # (J_g^t * M^t) (y - M * G(z)) - lambda * J_d
        b_right = _lambd * j_d
        b_left = a_left @ tf.reshape((y - m_go), [bs, n_pixels, 1])

        self.b_left = b_left
        self.b_right = b_right

        # Drop only the trailing size-1 axis produced by the matmul; an
        # unqualified squeeze would also drop any other size-1 axis.
        b = tf.squeeze(b_left, axis=-1) - b_right
        self.b = b

        self.dz = tf.linalg.matvec(a_inv, b)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment