Skip to content

Instantly share code, notes, and snippets.

@oO0oO0oO0o0o00
Last active July 23, 2021 03:00
Show Gist options
  • Save oO0oO0oO0o0o00/74dbcb352164348e5268203fdf95a04b to your computer and use it in GitHub Desktop.
Save oO0oO0oO0o0o00/74dbcb352164348e5268203fdf95a04b to your computer and use it in GitHub Desktop.
Gradient Reversal Layer for tf.keras of TensorFlow 2
import tensorflow as tf
from tensorflow import keras
class GradientReversal(keras.layers.Layer):
    """Identity layer that reverses (and scales) the gradient in the backward pass.

    The forward pass returns the input unchanged. During backpropagation the
    incoming gradient is multiplied by ``-λ``, so layers placed before this
    one receive a sign-flipped, λ-scaled gradient.

    Based on https://github.com/michetonu/gradient_reversal_keras_tf,
    ported to TF 2.x.

    Args:
        λ: Scale applied to the reversed gradient. The default of 1 gives a
            pure sign flip.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, λ=1, **kwargs):
        super().__init__(**kwargs)
        self.λ = λ

    @staticmethod
    @tf.custom_gradient
    def reverse_gradient(x, λ):
        """Return ``x`` unchanged, paired with a gradient function yielding ``-λ * dy``."""
        # @tf.custom_gradient suggested by Hoa's comment at
        # https://stackoverflow.com/questions/60234725/how-to-use-gradient-override-map-with-tf-gradienttape-in-tf2-0

        def grad(dy):
            # λ may arrive as a Python number or as a tensor of another dtype
            # (e.g. the int default); cast it so it can scale a float gradient.
            scale = tf.cast(λ, dy.dtype)
            # λ is a hyperparameter, so no gradient is propagated back to it.
            return -scale * dy, None

        return tf.identity(x), grad

    def call(self, x):
        """Apply the identity forward pass with the reversed-gradient backward pass."""
        return self.reverse_gradient(x, self.λ)

    def compute_mask(self, inputs, mask=None):
        """Pass any incoming mask through unchanged."""
        return mask

    def compute_output_shape(self, input_shape):
        """The output has the same shape as the input."""
        return input_shape

    def get_config(self):
        """Return the base layer config extended with ``λ``."""
        # Dict unpacking instead of the ``|`` operator keeps this working on
        # Python versions older than 3.9.
        return {**super().get_config(), 'λ': self.λ}
###########################################################
# below is an example
import numpy as np
# One sample with two features, used as the probe input for both models.
x = np.array([[1, 1.5]], dtype=np.float32)

# Bias-free dense layer whose kernel entries all start at the same constant.
model1 = keras.Sequential(
    [
        keras.layers.Dense(
            2,
            use_bias=False,
            kernel_initializer=keras.initializers.Constant(2.333),
            input_dim=2,
        ),
    ]
)

# The same network followed by the gradient-reversal layer.
model2 = keras.Sequential([model1, GradientReversal()])
model2.summary()
# model.compile(...)
# model.fit(...)
# model.evaluate(...)
###########################################################
# verification
# Record each model's forward pass on its own tape so the two sets of
# gradients can be compared afterwards.
with tf.GradientTape() as tape1:
    y1 = model1(x)  # plain model

with tf.GradientTape() as tape2:
    y2 = model2(x)  # plain model followed by GradientReversal
def print_this(tensors):
    """Print one line per tensor: its shape, dtype name and flattened values.

    Each element of ``tensors`` must provide a ``.numpy()`` method; the
    resulting array is what gets described.
    """
    for item in tensors:
        array = item.numpy()
        dims = 'x'.join(map(str, array.shape))
        values = array.ravel().tolist()
        print(f"{dims} (tf.{array.dtype.name}): {values}")
# Gradients without the reversal layer ...
print_this(
    tape1.gradient(target=y1, sources=model1.trainable_variables)
)
# ... and with it: the same values are expected with the opposite sign.
print_this(
    tape2.gradient(target=y2, sources=model2.trainable_variables)
)
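# got:
# NOTE: the output recorded below lists both a kernel (2x2) and a bias (2) gradient,
# so it appears to come from an earlier variant of the model. The bias-free linear
# model defined above has a single 2x2 kernel, whose gradient should be
# [1.0, 1.0, 1.5, 1.5] for model1 and the negation of that for model2.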
# 2x2 (tf.float32): [0.0025975825265049934, 0.0025975825265049934, 0.00389637378975749, 0.00389637378975749]
# 2 (tf.float32): [0.0025975825265049934, 0.0025975825265049934]
# 2x2 (tf.float32): [-0.0025975825265049934, -0.0025975825265049934, -0.00389637378975749, -0.00389637378975749]
# 2 (tf.float32): [-0.0025975825265049934, -0.0025975825265049934]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment