Gradient Reversal Layer for tf.keras in TensorFlow 2. The layer is the identity in the forward pass and multiplies the incoming gradient by -λ in the backward pass, as used in domain-adversarial training (DANN).
import tensorflow as tf
from tensorflow import keras


class GradientReversal(keras.layers.Layer):
    """Flip the sign of the gradient during training.

    Identity in the forward pass; multiplies the incoming gradient
    by -λ in the backward pass. Based on
    https://github.com/michetonu/gradient_reversal_keras_tf,
    ported to TF 2.x.
    """

    def __init__(self, λ=1, **kwargs):
        super(GradientReversal, self).__init__(**kwargs)
        self.λ = λ

    @staticmethod
    @tf.custom_gradient
    def reverse_gradient(x, λ):
        # @tf.custom_gradient suggested by Hoa's comment at
        # https://stackoverflow.com/questions/60234725/how-to-use-gradient-override-map-with-tf-gradienttape-in-tf2-0
        # Forward: identity. Backward: negate and scale by λ.
        # λ itself is not differentiated, hence the None.
        return tf.identity(x), lambda dy: (-λ * dy, None)

    def call(self, x):
        return self.reverse_gradient(x, self.λ)

    def compute_mask(self, inputs, mask=None):
        return mask

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        # dict union (|) requires Python 3.9+
        return super(GradientReversal, self).get_config() | {'λ': self.λ}
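###########################################################
# quick sanity check of the raw op, separate from the model-level
# verification further down; a minimal sketch: the forward value is
# unchanged while the gradient comes back sign-flipped
v = tf.constant([3.0, -1.0])
with tf.GradientTape() as tape:
    tape.watch(v)
    out = GradientReversal.reverse_gradient(v, 1.0)
print(out.numpy())                     # [ 3. -1.] -- identity forward
print(tape.gradient(out, v).numpy())   # [-1. -1.] -- reversed gradient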
###########################################################
# below is an example

import numpy as np

x = np.array([[1, 1.5]], dtype=np.float32)
model1 = keras.Sequential([
    keras.layers.Dense(2, use_bias=False,
                       kernel_initializer=keras.initializers.Constant(2.333),
                       input_dim=2)
])
model2 = keras.Sequential([model1, GradientReversal()])
model2.summary()
# model.compile(...)
# model.fit(...)
# model.evaluate(...)
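# A minimal sketch of how the placeholders above might be filled in;
# the optimizer, loss, and toy targets are assumptions, not part of
# the original gist. Left commented out because running fit() would
# update the weights and change the gradient values printed below.
#
# model2.compile(optimizer='adam', loss='mse')
# model2.fit(x, np.zeros((1, 2), dtype=np.float32), epochs=1)
# model2.evaluate(x, np.zeros((1, 2), dtype=np.float32))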
###########################################################
# verification

with tf.GradientTape() as tape1:
    y1 = model1(x)
with tf.GradientTape() as tape2:
    y2 = model2(x)

def print_this(tensors):
    for tensor in tensors:
        tensor = tensor.numpy()
        shape = 'x'.join(str(dim) for dim in tensor.shape)
        print(f"{shape} (tf.{tensor.dtype.name}): {tensor.ravel().tolist()}")

print_this(tape1.gradient(y1, model1.trainable_variables))
print_this(tape2.gradient(y2, model2.trainable_variables))
# got:
# 2x2 (tf.float32): [0.0025975825265049934, 0.0025975825265049934, 0.00389637378975749, 0.00389637378975749]
# 2 (tf.float32): [0.0025975825265049934, 0.0025975825265049934]
# 2x2 (tf.float32): [-0.0025975825265049934, -0.0025975825265049934, -0.00389637378975749, -0.00389637378975749]
# 2 (tf.float32): [-0.0025975825265049934, -0.0025975825265049934]
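# The second set of gradients is exactly the negation of the first,
# which confirms the layer leaves the forward pass untouched and only
# flips (and, for λ ≠ 1, scales) the backward pass.

# Since the layer implements get_config, it survives a config round
# trip; a minimal sketch (λ=0.5 is just an illustrative value):
config = GradientReversal(λ=0.5).get_config()
clone = GradientReversal.from_config(config)
assert clone.λ == 0.5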