Last active: November 15, 2020 20:19
Save aniketbiprojit/ce1611ec7f4a44c0b95c1d655b1ad2f1 to your computer and use it in GitHub Desktop.
Client Federated Average
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import collections
import os

# Must be set before TensorFlow is imported: TensorFlow reads it when it
# loads, and level '2' hides its C++ INFO and WARNING log lines.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf  # noqa: E402  (must follow the env var above)
import tensorflow.keras as keras  # noqa: E402

# NOTE(review): the two imports below come from TensorFlow Federated's
# private `python.*` package paths, which are not a stable public API, and
# neither name is referenced in the visible code — confirm they are needed.
from tensorflow_federated.python.tensorflow_libs import tensor_utils  # noqa: E402,F401
from tensorflow_federated.python.learning.framework import optimizer_utils  # noqa: E402,F401
# Server (global) model. Its weights are the reference point for the
# per-client deltas and are blended with the client average at the end.
model = keras.models.load_model('./model.h5')
model.compile('SGD')

# Client models are every non-hidden entry in ./keras_models, taken in sorted
# order so runs are reproducible (os.listdir order is arbitrary). Hidden
# entries such as .DS_Store are skipped because load_model would fail on them.
client_files = sorted(
    fname for fname in os.listdir('keras_models') if not fname.startswith('.')
)
if not client_files:
    # Without this guard the means below reduce over an empty list and the
    # script fails later with an unrelated-looking shape error.
    raise FileNotFoundError('no client models found in ./keras_models')
client_models = [
    keras.models.load_model(f'./keras_models/{fname}') for fname in client_files
]

# weights[i] collects the i-th weight tensor of every client model, for every
# weight tensor the server model has (previously only the first two).
weights = [[] for _ in model.weights]
# Per-client (client - server) weight deltas; not consumed in the visible code.
weights_delta_arr = []
for fname, client_model in zip(client_files, client_models):
    # Mirrors the server model's compile; not required just to read weights.
    client_model.compile('SGD')
    if len(client_model.weights) != len(model.weights):
        raise ValueError(
            f'client model {fname!r} has {len(client_model.weights)} weight '
            f'tensors but the server model has {len(model.weights)}'
        )
    for per_tensor, client_weight in zip(weights, client_model.weights):
        per_tensor.append(client_weight)
    weights_delta = tf.nest.map_structure(
        tf.subtract, client_model.weights, model.weights
    )
    weights_delta_arr.append(weights_delta)

# Element-wise mean of each weight tensor across clients. This is an
# unweighted average: every client counts equally.
updated_weights = [tf.reduce_mean(per_tensor, 0) for per_tensor in weights]
# Equal 50/50 blend of the client average with the current server weights.
updated_weights_with_server = [
    tf.reduce_mean([client_avg, server_weight], 0)
    for client_avg, server_weight in zip(updated_weights, model.weights)
]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.