Skip to content

Instantly share code, notes, and snippets.

@SirmaXX
Created July 17, 2024 10:38
Show Gist options
  • Save SirmaXX/66e79fc241afbe05529121c3351960c9 to your computer and use it in GitHub Desktop.
Save SirmaXX/66e79fc241afbe05529121c3351960c9 to your computer and use it in GitHub Desktop.
Federated learning
import argparse
import os
from flwr.client import ClientApp, NumPyClient
import tensorflow as tf
from flwr_datasets import FederatedDataset
# Command-line options for this client process. Only --partition-id is
# recognised; parse_known_args() tolerates (and discards) any extra flags.
parser = argparse.ArgumentParser(description="Flower")
parser.add_argument(
    "--partition-id",
    type=int,
    choices=list(range(3)),  # one id per dataset partition: 0, 1 or 2
    default=0,
    help="Partition of the dataset",
)
args, _ = parser.parse_known_args()
# Build MobileNetV2 from scratch (weights=None: no pretrained weights are
# loaded) for 32x32 RGB inputs with a 10-way classification head.
model = tf.keras.applications.MobileNetV2(
    input_shape=(32, 32, 3),
    classes=10,
    weights=None,
)
# sparse_categorical_crossentropy expects integer class ids rather than
# one-hot vectors as labels; "adam" uses the optimizer's default settings.
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
# Download CIFAR-10 and divide its "train" split into 3 partitions; this
# client loads only the partition selected by --partition-id.
# NOTE(review): passing an int partition count presumably selects IID
# partitioning in flwr_datasets -- confirm against the installed version.
fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3})
partition = fds.load_partition(args.partition_id, "train")
partition.set_format("numpy")
# Hold out 20% of this client's partition for local evaluation; the fixed
# seed keeps the split identical across runs.
partition = partition.train_test_split(test_size=0.2, seed=42)
# Cast to float32 *before* scaling. Integer-array / 255.0 yields a float64
# array (twice the memory), which Keras would presumably convert to the
# model's float32 dtype anyway. Assumes images arrive as integer pixel arrays
# in 0..255 (implied by the /255.0 scaling) -- confirm.
x_train = partition["train"]["img"].astype("float32") / 255.0
y_train = partition["train"]["label"]
x_test = partition["test"]["img"].astype("float32") / 255.0
y_test = partition["test"]["label"]
# Define Flower client
class FlowerClient(NumPyClient):
    """Flower client that trains and evaluates the module-level Keras ``model``.

    The model and the data arrays (``x_train``/``y_train``, ``x_test``/``y_test``)
    are module-level globals, so every instance in this process shares the
    same model object and the same data partition.
    """

    def get_parameters(self, config):
        """Return the current model weights.

        ``config`` is part of the NumPyClient interface and is unused here.
        """
        return model.get_weights()

    def fit(self, parameters, config):
        """Load the server's weights, train locally, and return the update.

        Optional ``config`` keys:
            ``local_epochs``: number of local epochs (default 1).
            ``batch_size``: mini-batch size (default 32).

        Returns:
            ``(updated_weights, number_of_training_examples, {})``.
        """
        model.set_weights(parameters)
        # Defaults reproduce the previously hard-coded values, so servers that
        # send an empty config see unchanged behaviour. int() normalises values
        # that arrive as other scalar types (e.g. float).
        epochs = int(config.get("local_epochs", 1))
        batch_size = int(config.get("batch_size", 32))
        model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size)
        return model.get_weights(), len(x_train), {}

    def evaluate(self, parameters, config):
        """Load the server's weights and evaluate on the local held-out split.

        Returns:
            ``(loss, number_of_test_examples, {"accuracy": accuracy})``.
        """
        model.set_weights(parameters)
        loss, accuracy = model.evaluate(x_test, y_test)
        return loss, len(x_test), {"accuracy": accuracy}
def client_fn(cid: str):
    """Create the Flower client handed to the ClientApp runtime.

    ``cid`` is accepted to satisfy the callback signature but is not used:
    every client built here trains on the same module-level model and data.
    """
    numpy_client = FlowerClient()
    return numpy_client.to_client()
# Flower ClientApp wrapping the client factory above.
app = ClientApp(client_fn=client_fn)

if __name__ == "__main__":
    # Standalone mode: connect this client to a server on localhost:8080.
    # NOTE(review): this module also contains a server section with its own
    # `if __name__ == "__main__"` guard and rebinds `app` to a ServerApp
    # further down. It looks like client.py and server.py were pasted into
    # one gist; keep them as separate files so each entry point runs alone.
    from flwr.client import start_client

    local_client = FlowerClient().to_client()
    start_client(server_address="127.0.0.1:8080", client=local_client)
from typing import List, Tuple
from flwr.server import ServerApp, ServerConfig
from flwr.server.strategy import FedAvg
from flwr.common import Metrics
# Define metric aggregation function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Combine per-client accuracies into an example-weighted mean.

    Each entry pairs the number of evaluation examples a client reported
    with the metrics dict it returned.

    Args:
        metrics: ``(num_examples, metrics_dict)`` pairs; every dict must
            contain an ``"accuracy"`` entry.

    Returns:
        ``{"accuracy": weighted_mean}``, or ``{}`` when no examples were
        reported at all (empty list or all counts zero), instead of raising
        ZeroDivisionError.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        # Nothing to average; report no metric rather than a fake 0.0.
        return {}
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_sum / total_examples}
# Run three federated rounds.
config = ServerConfig(num_rounds=3)
# Federated averaging; client-reported evaluation accuracies are combined
# with weighted_average (weighted by each client's example count).
strategy = FedAvg(evaluate_metrics_aggregation_fn=weighted_average)
# Flower ServerApp bundling the round configuration and the strategy.
# NOTE(review): this rebinds the module-level name `app`, which earlier in
# this file holds the ClientApp; the two sections look like separate scripts
# pasted together -- split them into their own modules.
app = ServerApp(config=config, strategy=strategy)
if __name__ == "__main__":
    # Standalone mode: listen on all interfaces, port 8080, using the round
    # configuration and strategy defined above.
    from flwr.server import start_server

    start_server(
        server_address="0.0.0.0:8080",
        config=config,
        strategy=strategy,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment