Skip to content

Instantly share code, notes, and snippets.

@aniketbiprojit
Last active November 15, 2020 20:19
Show Gist options
  • Save aniketbiprojit/ce1611ec7f4a44c0b95c1d655b1ad2f1 to your computer and use it in GitHub Desktop.
Client Federated Average
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import tensorflow.keras as keras
# import tensorflow_federated as tff
from tensorflow_federated.python.tensorflow_libs import tensor_utils
from tensorflow_federated.python.learning.framework import optimizer_utils
import collections
# Server-side (global) model; each client's weights are compared against and
# blended with this model's weights further down.
SERVER_MODEL_PATH = './model.h5'
model = keras.models.load_model(SERVER_MODEL_PATH)
model.compile(optimizer='SGD')
# Directory holding one saved Keras model per federated client.
CLIENT_MODELS_DIR = 'keras_models'

# weights[i] collects the i-th weight tensor of every client model so it can be
# averaged across clients afterwards. There is one slot per server weight
# tensor (previously exactly two were hard-coded), so every layer is collected.
weights = [[] for _ in model.weights]
# One entry per client: that client's weights minus the server weights,
# computed tensor by tensor.
weights_delta_arr = []

# os.listdir() returns entries in arbitrary order; sorting makes the client
# order (and therefore the floating-point summation order) reproducible.
client_models = [
    keras.models.load_model(os.path.join(CLIENT_MODELS_DIR, file_name))
    for file_name in sorted(os.listdir(CLIENT_MODELS_DIR))
]
if not client_models:
    # With zero clients there is nothing to average; fail loudly here instead
    # of producing a meaningless aggregate later on.
    raise RuntimeError(f'No client models found in {CLIENT_MODELS_DIR!r}')

for client_model in client_models:
    client_model.compile('SGD')
    # Computed before recording anything: tf.nest.map_structure raises if the
    # client's weight list does not match the server model's structure.
    weights_delta = tf.nest.map_structure(
        tf.subtract, client_model.weights, model.weights
    )
    weights_delta_arr.append(weights_delta)
    for slot, client_weight in zip(weights, client_model.weights):
        slot.append(client_weight)
# Element-wise mean of each weight tensor across all clients. The mean is
# unweighted: every client contributes equally, regardless of how much data it
# trained on. Iterates over every collected tensor instead of only the first
# two.
updated_weights = [
    tf.reduce_mean(client_tensors, 0) for client_tensors in weights
]
# Blend the client average 50/50 with the current server weights:
# reduce_mean over the pair [a, b] along axis 0 is (a + b) / 2.
updated_weights_with_server = [
    tf.reduce_mean([client_average, server_weight], 0)
    for client_average, server_weight in zip(updated_weights, model.weights)
]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment