Skip to content

Instantly share code, notes, and snippets.

@ma7555
Created November 5, 2024 15:13
Show Gist options
  • Save ma7555/2f56163ce815e042e018f62e16f7c2c8 to your computer and use it in GitHub Desktop.
Save ma7555/2f56163ce815e042e018f62e16f7c2c8 to your computer and use it in GitHub Desktop.
Keras 3 XBM
import keras
class XBM(keras.losses.Loss):
    """Loss wrapper that adds a memory of past batches to a pairwise loss.

    On every call the current batch's embeddings and labels are pushed to
    the front of a bounded memory. Once more than ``warmup_steps`` calls
    have been made, the wrapped loss compares the current batch against
    the current batch *plus* the memory instead of the batch alone.

    The wrapped loss must expose ``fn`` (a callable accepting
    ``ref_labels=`` and ``ref_embeddings=`` keyword arguments) and
    ``_fn_kwargs`` (a dict of extra keyword arguments for ``fn``); both are
    read in ``call``.

    The object keeps Python-side state (step counter and memory tensors)
    that is mutated on every call.
    NOTE(review): this presumably requires eager execution, and any call
    made during validation would also advance the counter and enter the
    memory -- confirm against how the loss is driven.

    Args:
        inner_loss: Loss object providing ``fn`` and ``_fn_kwargs``.
        memory_size: Maximum number of samples kept in the memory.
        warmup_steps: Number of initial calls during which the memory is
            filled but not used as reference.
        name: Name forwarded to ``keras.losses.Loss``.
        **kwargs: Forwarded to ``keras.losses.Loss``.

    Raises:
        ValueError: If ``memory_size`` is negative.
    """

    def __init__(
        self,
        inner_loss,
        memory_size=1024,
        warmup_steps=0,
        name="xbm_loss",
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)
        # A negative size would turn the `[:memory_size]` truncation into
        # "drop the last k entries", which is never what is wanted.
        if memory_size < 0:
            raise ValueError(
                f"memory_size must be non-negative, got {memory_size}"
            )
        self.inner_loss = inner_loss
        self.memory_size = memory_size
        self.warmup_steps = warmup_steps
        self.total_steps = 0
        # Allocated lazily on the first call, once the embedding width and
        # the dtypes are known.
        self.embeddings_memory = None
        self.labels_memory = None

    def call(self, y_true, y_pred):
        """Compute the wrapped loss against batch + memory, then update memory.

        Args:
            y_true: Labels for the batch; a trailing axis is squeezed away
                when the tensor has more than one dimension.
            y_pred: Embeddings for the batch; the last axis is the
                embedding dimension.

        Returns:
            Whatever ``inner_loss.fn`` returns for this batch.
        """
        if self.embeddings_memory is None:
            embedding_dim = y_pred.shape[-1]
            self.embeddings_memory = keras.ops.zeros(
                (0, embedding_dim), dtype=y_pred.dtype
            )
            self.labels_memory = keras.ops.zeros((0,), dtype=y_true.dtype)

        # Labels must be 1-D to be concatenated with the label memory.
        # Only squeeze when there is a trailing axis to remove, so labels
        # that are already 1-D are accepted as well.
        if len(y_true.shape) > 1:
            y_true = keras.ops.squeeze(y_true, axis=-1)

        self.total_steps += 1
        if self.total_steps <= self.warmup_steps:
            # Warm-up: the memory is not used as reference yet.
            ref_embeddings = y_pred
            ref_labels = y_true
        else:
            ref_embeddings = keras.ops.concatenate(
                [y_pred, self.embeddings_memory], axis=0
            )
            ref_labels = keras.ops.concatenate(
                [y_true, self.labels_memory], axis=0
            )

        loss = self.inner_loss.fn(
            y_true,
            y_pred,
            ref_labels=ref_labels,
            ref_embeddings=ref_embeddings,
            **self.inner_loss._fn_kwargs,
        )

        # Store a gradient-free copy: memorised embeddings act as constants
        # in later steps and must not keep this step's graph reachable.
        detached_pred = keras.ops.stop_gradient(y_pred)
        # Newest samples go first; truncating the tail evicts the oldest.
        self.embeddings_memory = keras.ops.concatenate(
            [detached_pred, self.embeddings_memory], axis=0
        )[: self.memory_size]
        self.labels_memory = keras.ops.concatenate(
            [y_true, self.labels_memory], axis=0
        )[: self.memory_size]
        return loss
import keras

# Load the data first: the hyper-parameters below are derived from the
# size of the training set.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

batch_size = 2048
# `*` and `//` associate left-to-right, so this is (2 * N) // batch_size.
warmup_steps = 2 * len(x_train) // batch_size
# Keep up to 16 batches in memory, but never more than the whole training set.
memory_size = min(batch_size * 16, len(x_train))
circle_loss = keras.losses.Circle()

model = keras.Sequential()
model.add(keras.layers.InputLayer(shape=(32, 32, 3)))
# NOTE(review): scale 1/255 with offset -1 maps [0, 255] to [-1, 0];
# confirm this range is intended.
model.add(keras.layers.Rescaling(1.0 / 255, offset=-1))
# NOTE(review): the source lost its indentation; pooling is assumed to
# follow every convolution inside the loop -- confirm.
for i in range(3):
    model.add(
        keras.layers.Conv2D(
            32, (3, 3), padding="valid", activation="relu", name=f"conv_{i}"
        )
    )
    model.add(keras.layers.MaxPooling2D((2, 2)))
model.add(keras.layers.Flatten())
# Linear 64-d projection followed by unit normalisation of the embedding.
model.add(keras.layers.Dense(64, activation=None))
model.add(keras.layers.UnitNormalization())

xbm_loss = XBM(
    inner_loss=circle_loss,
    memory_size=memory_size,
    warmup_steps=warmup_steps,
)
# XBM mutates Python attributes on every call, hence eager execution.
model.compile(optimizer="adam", loss=xbm_loss, run_eagerly=True)
model.fit(
    x_train,
    y_train,
    validation_data=(x_test, y_test),
    epochs=5,
    batch_size=batch_size,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment