Tests for W&B Gradients Callback for Keras v3
# Hide GPU to help with reproducibility between Keras versions
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Copilot says these apparently reduce the chances of reproducibility issues; treat with caution
os.environ["TF_DETERMINISTIC_OPS"] = "1"
import keras, tensorflow as tf

## Source: https://keras.io/examples/keras_recipes/reproducibility_recipes/ ##################
# Set the seed using keras.utils.set_random_seed. This will set:
# 1) `numpy` seed
# 2) backend random seed
# 3) `python` random seed
keras.utils.set_random_seed(812)

# If using TensorFlow, this will make GPU ops as deterministic as possible,
# but it will affect the overall performance, so be mindful of that.
tf.config.experimental.enable_op_determinism()
#############################################################################################

# Copilot says these apparently reduce the chances of reproducibility issues; treat with caution
tf.config.threading.set_inter_op_parallelism_threads(1)
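
# A quick self-check of the seeding above (an illustrative addition): after
# reseeding, the Python RNG should reproduce exactly the same draws.
import random
firstDraws = [random.random() for _ in range(3)]
keras.utils.set_random_seed(812)
secondDraws = [random.random() for _ in range(3)]
assert firstDraws == secondDraws
keras.utils.set_random_seed(812) # restore the state the rest of the script expects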
import numpy as np
import wandb
from wandb.integration.keras import WandbCallback
from datetime import datetime

kerasVersion = keras.__version__ # keras.version() is available in v3.
launchTime = datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
# Random numbers make the problem more realistic but trickier to debug.
UseRandomNumbers = True

# Name the run with the Keras version, launch time, and dataset variant
run = wandb.init(
    project = "TestKEras3Support",
    name = f"keras_{ kerasVersion }-{ launchTime } - { 'random' if UseRandomNumbers else 'simple' }"
)

# Keep the Keras version in the W&B config
wandb.config.keras_version = kerasVersion
# Define a simple model: one Dense unit with no bias and a kernel of ones,
# so initially it outputs the exact sum of its inputs
model = keras.models.Sequential([
    keras.layers.Dense(
        1,
        input_shape = (10,),
        kernel_initializer = 'ones',
        use_bias = False)
])

model.compile(
    optimizer = 'sgd',
    loss = 'mean_squared_error',
    jit_compile = False, # May help reproducibility between Keras versions
)
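
# Sanity check (illustrative, not load-bearing): with the 'ones' kernel and no
# bias, the untrained model should already sum its inputs exactly.
probe = np.ones((1, 10), dtype = "float32")
assert np.allclose(np.asarray(model(probe)), 10.0)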
# Random numbers make the problem more realistic but, whilst the tensors seem
# identical between Keras 2 and 3, the loss and gradients look different.
# Those values did make sense in testing, though, based on the loss curves.
if UseRandomNumbers:
    X_train = np.random.rand(1000, 10)
    y_train = np.random.rand(1000, 1)
    X_validate = np.random.rand(1000, 10)
    y_validate = np.random.rand(1000, 1)

# With this branch the gradients are constant - probably because the loss goes
# to zero very rapidly on such a trivial problem.
else:
    # Generate some dummy data
    X_train = np.linspace(0, 1, 10000).reshape(1000, 10)
    X_validate = np.linspace(1, 2, 10000).reshape(1000, 10)

    # Trivial output task: sum the inputs
    y_train = np.sum(X_train, axis=1, keepdims=True)
    y_validate = np.sum(X_validate, axis=1, keepdims=True)
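
# Illustrate the comment above (an added check using only the script's own model
# and data): on the trivial task the initial kernel of ones already computes the
# exact sum, so the starting loss is ~0 and the gradients barely move.
if not UseRandomNumbers:
    initialPredictions = model.predict(X_train, verbose = 0)
    print("Initial MSE on the trivial task:", float(np.mean((initialPredictions - y_train) ** 2)))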
# Use the old callback to log gradients and weights easily:
gradientsCallback = WandbCallback(
    # It seems we have to pass these:
    monitor = "val_loss", verbose = 0, mode = "auto",

    # Save weights and gradients to look for vanishing and exploding values
    log_weights = True, log_gradients = True,
    training_data = (X_train, y_train),
    validation_data = None, predictions = 1,

    # And turn off things we don't want:
    save_model = False, save_graph = False,
)
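# Note: with log_gradients = True, the callback needs training_data so it can
# evaluate the loss gradients itself each epoch and log them alongside the weights.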

callbacks = [
    gradientsCallback,
]
# This should hopefully now be deterministic across runs.
model.fit(
    X_train, y_train,
    callbacks = callbacks,
    validation_data = (X_validate, y_validate),
    batch_size = 4, # 1
    epochs = 16,
)
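
# A rough manual cross-check of the logged gradients (assumes the TensorFlow
# backend): recompute the mean-squared-error gradients over the training set
# with a GradientTape, as in Keras's custom-training-loop guide, and print
# their norms for comparison against what the callback recorded.
inputs = tf.convert_to_tensor(X_train, dtype = tf.float32)
targets = tf.convert_to_tensor(y_train, dtype = tf.float32)
with tf.GradientTape() as tape:
    predictions = model(inputs)
    loss = tf.reduce_mean(tf.square(predictions - targets))
gradients = tape.gradient(loss, model.trainable_weights)
for weight, gradient in zip(model.trainable_weights, gradients):
    print(weight.name, "gradient norm:", float(tf.norm(gradient)))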

run.finish()

breakpoint() # Drop into the debugger to inspect the run before exiting