# Gradient Reversal Layer for tf.keras (TensorFlow 2).
# (Gist last active December 30, 2024 05:15.)
import tensorflow as tf
from tensorflow import keras
class GradientReversal(keras.layers.Layer):
    """Flip the sign of the gradient during training.

    The forward pass is the identity. In the backward pass the incoming
    gradient is multiplied by ``-λ``, so layers before this one are
    updated to *increase* whatever loss is computed after it.

    Based on an earlier Keras gradient-reversal implementation, ported to
    TF 2.x.

    Args:
        λ: scale applied to the reversed gradient. The default 1 gives a
            plain sign flip.
        **kwargs: forwarded unchanged to ``keras.layers.Layer``.
    """

    def __init__(self, λ=1, **kwargs):
        super(GradientReversal, self).__init__(**kwargs)
        self.λ = λ

    @staticmethod
    @tf.custom_gradient
    def reverse_gradient(x, λ):
        """Return ``x`` unchanged, registering ``-λ * dy`` as its gradient.

        No gradient is produced for ``λ`` itself.
        """
        # @tf.custom_gradient suggested by Hoa's comment (link lost in this copy).

        def grad(dy):
            # λ may reach here as a Python number or as a tensor converted
            # from it (possibly an integer dtype); cast it to the gradient's
            # dtype so the multiplication is well-typed.
            return -tf.cast(λ, dy.dtype) * dy, None

        return tf.identity(x), grad

    def call(self, x):
        """Pass ``x`` through unchanged; only its gradient is reversed."""
        return self.reverse_gradient(x, self.λ)

    def compute_mask(self, inputs, mask=None):
        """Propagate any incoming mask unchanged."""
        return mask

    def compute_output_shape(self, input_shape):
        """The output has the same shape as the input."""
        return input_shape

    def get_config(self):
        """Extend the base layer config with ``λ`` so the layer can be re-created."""
        return super(GradientReversal, self).get_config() | {'λ': self.λ}
# below is an example
import numpy as np

x = np.array([[1, 1.5]], dtype=np.float32)
model1 = keras.Sequential([
    keras.Input(shape=(2,)),
    keras.layers.Dense(
        2,
        use_bias=False,
        kernel_initializer=keras.initializers.Constant(2.333),
    ),
])
model2 = keras.Sequential([model1, GradientReversal()])
# model.compile(...)
# model.evaluate(...)

# verification: model2 shares model1's kernel, so its gradient should be
# exactly the negation of model1's (default λ=1).
with tf.GradientTape() as tape1:
    y1 = model1(x)
with tf.GradientTape() as tape2:
    y2 = model2(x)


def print_this(tensors):
    """Print each tensor as '<shape> (tf.<dtype>): <flattened values>'."""
    for tensor in tensors:
        array = tensor.numpy()
        shape = 'x'.join(str(dim) for dim in array.shape)
        print(f"{shape} (tf.{array.dtype}): {array.ravel().tolist()}")


print_this(tape1.gradient(y1, model1.trainable_variables))
print_this(tape2.gradient(y2, model2.trainable_variables))
# expected (Dense has no bias, so only the 2x2 kernel gradient is printed;
# d sum(x @ K) / dK[i, j] = x[i]):
# 2x2 (tf.float32): [1.0, 1.0, 1.5, 1.5]
# 2x2 (tf.float32): [-1.0, -1.0, -1.5, -1.5]
