Skip to content

Instantly share code, notes, and snippets.

@sherjilozair
Last active May 3, 2018 03:08
Show Gist options
  • Save sherjilozair/cd1791574e201155e448034d1ed6a544 to your computer and use it in GitHub Desktop.
Spectral Normalization
import tensorflow as tf
import numpy as np
class SpecNorm(object):
    """Mixin adding spectral normalization (Miyato et al., 2018) to a
    tf.layers layer that owns a ``self.kernel`` weight.

    The kernel is viewed as an (indim, outdim) matrix W; one step of power
    iteration per call refines the singular vectors ``u`` (1, indim) and
    ``v`` (outdim, 1), and the layer applies W / sigma where
    sigma ~= u @ W @ v is the estimated largest singular value.
    """

    def build(self, input_shape):
        # NOTE(review): assumes the wrapped layer defines `kernel_size` and
        # `filters` (i.e. it is a conv layer) -- confirm before mixing into
        # other layer types.
        self.indim = input_shape[-1] * np.array(self.kernel_size).prod()
        self.outdim = self.filters
        # Persistent power-iteration vectors; updated via assign ops in
        # call(), never by gradients.
        self.u = self.add_variable(name='u', shape=(1, self.indim), trainable=False)
        self.v = self.add_variable(name='v', shape=(self.outdim, 1), trainable=False)
        super().build(input_shape)
        # sigma estimate: u @ W @ v, a (1, 1) tensor; broadcasting divides
        # every kernel element by it.
        self.norm = tf.matmul(
            tf.matmul(self.u, tf.reshape(self.kernel, [self.indim, self.outdim])),
            self.v)
        self.oldkernel = self.kernel
        self.kernel = self.kernel / self.norm

    def call(self, inputs):
        # BUG FIX: the update ops were previously placed AFTER the return
        # statement (dead code, so u/v never changed) and referenced bare
        # `u`/`v` (NameError had they run).  They also needed transposes so
        # the assigned values match the variables' shapes.
        w = tf.reshape(self.oldkernel, [self.indim, self.outdim])
        # u <- normalize((W v)^T): W v is (indim, 1), transposed to (1, indim).
        new_u = tf.nn.l2_normalize(tf.transpose(tf.matmul(w, self.v)))
        # v <- normalize((u W)^T): u W is (1, outdim), transposed to (outdim, 1).
        new_v = tf.nn.l2_normalize(tf.transpose(tf.matmul(self.u, w)))
        # Register the power-iteration step to run whenever the layer runs.
        self.add_update([tf.assign(self.u, new_u), tf.assign(self.v, new_v)])
        return super().call(inputs)
class SpecNormConv1D(SpecNorm, tf.layers.Conv1D):
    """1-D convolution layer whose kernel is spectrally normalized.

    All behavior comes from the SpecNorm mixin combined with
    tf.layers.Conv1D; no overrides are needed here.
    """
if __name__ == '__main__':
    # Graph-mode smoke test (TF1): batch of 128 sequences, 256 steps,
    # 32 channels, pushed through a spectrally-normalized 1-D conv with
    # 64 filters and kernel width 3.
    example_input = tf.placeholder(tf.float32, [128, 256, 32])
    layer = SpecNormConv1D(64, 3, padding='same')
    example_output = layer(example_input)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment