Skip to content

Instantly share code, notes, and snippets.

@cbonesana
Created June 11, 2024 21:32
Show Gist options
  • Save cbonesana/19550cf5fad3a3f543360a29c7eb9f3b to your computer and use it in GitHub Desktop.
Save cbonesana/19550cf5fad3a3f543360a29c7eb9f3b to your computer and use it in GitHub Desktop.
Federated Learning is just a glorified chat

Federated Learning is just a glorified chat

Federated Learning is just a fancy name for a distributed algorithm. Many frameworks want to provide a special implementation of this idea, making it look like something complicated that requires specialized software and expertise. In the real world, all you need is a web service endpoint and a few lines of code implementing the algorithm.

A good chat application is harder to implement than a complex federated learning algorithm...

What am I looking at?

This gist contains an implementation of BOFRF, the Boosting-Based Federated Random Forest algorithm, built with FastAPI. Nothing else.

The code can be tested using Docker. To run the example, use the docker-compose.yaml file below. This file requires building a container with the dependencies (the Dockerfile is also provided below) and mounting local data.

This code is only a demonstration and is not intended to be run in production.

# docker-compose.yaml
# Build the image first (see the Dockerfile): docker build -t federated:1.0 .
# N_SITES must equal the number of datanode_* services defined below.
version: "3.7"

networks:
  algo:
    external: false
    attachable: true

services:
  coordinator:
    image: federated:1.0
    hostname: coordinator
    environment:
      - MODE=coordinator
      - N_SITES=3
      - HOSTNAME_COORDINATOR=coordinator
      - HOSTNAME=coordinator
    entrypoint: /bin/bash -c "uvicorn script:app --host 0.0.0.0 --port 80"
    volumes:
      - ./workdir/coordinator:/workdir
      - ./script.py:/script.py
    networks:
      - algo
    deploy:
      placement:
        constraints:
          - node.role==manager

  datanode_0:
    image: federated:1.0
    hostname: data0
    environment:
      - MODE=data
      - N_SITES=3
      - HOSTNAME_COORDINATOR=coordinator
      - HOSTNAME=data0
    entrypoint: /bin/bash -c "uvicorn script:app --host 0.0.0.0 --port 80"
    volumes:
      - ./workdir/0:/workdir
      - ./workdir/0/data.csv:/data.csv
      - ./script.py:/script.py
    networks:
      - algo
    depends_on:
      - coordinator
    deploy:
      placement:
        constraints:
          - node.role==manager

  datanode_1:
    image: federated:1.0
    hostname: data1
    environment:
      - MODE=data
      - N_SITES=3
      - HOSTNAME_COORDINATOR=coordinator
      - HOSTNAME=data1
    entrypoint: /bin/bash -c "uvicorn script:app --host 0.0.0.0 --port 80"
    volumes:
      - ./workdir/1:/workdir
      - ./workdir/1/data.csv:/data.csv
      - ./script.py:/script.py
    networks:
      - algo
    depends_on:
      - coordinator
    deploy:
      placement:
        constraints:
          - node.role==manager

  datanode_2:
    image: federated:1.0
    hostname: data2
    environment:
      - MODE=data
      - N_SITES=3
      - HOSTNAME_COORDINATOR=coordinator
      - HOSTNAME=data2
    entrypoint: /bin/bash -c "uvicorn script:app --host 0.0.0.0 --port 80"
    volumes:
      - ./workdir/2:/workdir
      - ./workdir/2/data.csv:/data.csv
      - ./script.py:/script.py
    networks:
      - algo
    depends_on:
      - coordinator
    deploy:
      placement:
        constraints:
          - node.role==manager
FROM python:3.12.3
# Runtime dependencies of script.py. httpx (HTTP client used between nodes)
# and uvicorn (the server started by the compose entrypoint) are listed
# explicitly: depending on the version, a plain `pip install fastapi` does
# not install them.
RUN pip install --no-cache-dir \
    pandas \
    scikit-learn \
    fastapi \
    uvicorn \
    httpx \
    aiofiles
from enum import Enum
from typing import Any
from contextlib import asynccontextmanager
from dataclasses import dataclass
from fastapi import FastAPI, BackgroundTasks, HTTPException, Request
from pathlib import Path
from pydantic import BaseModel
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix
from sklearn.tree import DecisionTreeClassifier
import aiofiles
import asyncio
import httpx
import json
import logging
import numpy as np
import os
import pandas as pd
import pickle
import time
import uuid
# Configure the root logger at import time so that every node, coordinator
# and data workers alike, emits timestamped lines with the same layout.
logging.basicConfig(
    format="%(asctime)s,%(msecs)03d %(levelname)8s %(name)16s:%(lineno)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
# Module-level logger shared by the endpoint and all background tasks.
LOGGER = logging.getLogger(name=__name__)
# parameters (all overridable through environment variables)
# role of this node: "data" (holds a local dataset) or "coordinator"
MODE: str = os.environ.get("MODE", "data")
# host names used to build the http:// URLs nodes advertise and contact
HOSTNAME: str = os.environ.get("HOSTNAME", "")
HOSTNAME_COORDINATOR: str = os.environ.get("HOSTNAME_COORDINATOR", "")
# number of data sites in the federation. The compose file sets N_SITES;
# the older N_NODES name is still honoured when N_SITES is absent.
N_SITES: int = int(os.environ.get("N_SITES", os.environ.get("N_NODES", 3)))
# number of trees in each site's local random forest
N_TREES: int = int(os.environ.get("N_TREES", 5))
# a tree gets a non-zero weight only if its MCC is strictly above TAU
TAU: float = float(os.environ.get("TAU", 0.2))
# local dataset (CSV with a "y" target column) and scratch directory
PATH_DATA: Path = Path(os.environ.get("PATH_DATA", "/data.csv"))
WORKING_DIR: Path = Path(os.environ.get("WORKING_DIR", "/workdir"))
# functions
def mcc(cm: np.ndarray) -> float:
    """Return the Matthews correlation coefficient of a 2x2 confusion matrix.

    ``cm`` is laid out as ``sklearn.metrics.confusion_matrix`` returns it for
    binary labels: ``[[tn, fp], [fn, tp]]``.

    When any row or column of ``cm`` sums to zero, the coefficient is
    undefined. In that case 0.0 is returned (the convention used by
    ``sklearn.metrics.matthews_corrcoef``) instead of NaN with a
    divide-by-zero warning.
    """
    # Work in float64: with integer input, the product of the four marginal
    # counts overflows int64 once counts reach roughly 55k.
    tn, fp, fn, tp = np.asarray(cm, dtype=np.float64).ravel()
    denominator = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
    if denominator == 0:
        return 0.0
    return float((tp * tn - fp * fn) / np.sqrt(denominator))
def dump(path: Path, o: Any | bytes) -> None:
with open(path, "wb") as f:
content: bytes = o if isinstance(o, bytes) else pickle.dumps(o)
f.write(content)
def load(path: Path) -> Any:
    """Read the pickle stored at ``path`` and return the resulting object.

    Unpickling can execute arbitrary code, so use this only on files from
    trusted sources.
    """
    with open(path, "rb") as source:
        obj = pickle.load(source)
    return obj
async def adump(path: Path, o: Any | bytes) -> None:
    """Asynchronous counterpart of ``dump``, built on ``aiofiles``.

    ``bytes`` are written verbatim; any other object is pickled first.
    """
    async with aiofiles.open(path, "wb") as sink:
        if isinstance(o, bytes):
            await sink.write(o)
        else:
            await sink.write(pickle.dumps(o))
async def aload(path: Path) -> Any:
    """Asynchronous counterpart of ``load``: read ``path`` with ``aiofiles`` and unpickle it.

    Returns the deserialized object. Unpickling can execute arbitrary code,
    so use this only on trusted files.
    """
    async with aiofiles.open(path, "rb") as f:
        content: bytes = await f.read()
    return pickle.loads(content)
# exchange data definition
class WorkerId(BaseModel):
    """Identity a node advertises to its peers; exchanged as JSON over HTTP."""

    # numeric index inside the federation (-1 until the coordinator assigns one)
    i: int
    # self-assigned identifier, "<MODE>-<8 hex chars of a uuid4>"
    id: str
    # base URL the node can be reached at, "http://<HOSTNAME>"
    url: str
class WorkerList(BaseModel):
    """Payload of the ``start`` command: the coordinator plus every registered worker."""

    # identity of the coordinator node
    coordinator: WorkerId
    # all data workers, each carrying its assigned index ``i``
    workers: list[WorkerId]
# execution classes
@dataclass
class Model:
    """Weighted ensemble of decision trees: a local or the global BOFRF model.

    Attributes:
        i: identifier of the model (a worker id, or "global").
        trees: fitted trees exposing ``predict_proba``.
        alpha: one weight per tree; this file builds it with shape (len(trees), 1).
    """

    i: str
    trees: list[DecisionTreeClassifier]
    alpha: np.ndarray

    def predict_proba(self, X) -> np.ndarray:
        """Return the alpha-weighted average of the trees' class probabilities.

        The result has shape (n_samples, n_classes), with each row normalised
        by the total weight. When every weight is 0 (all trees rejected by the
        MCC threshold), the trees are averaged with equal weights instead of
        producing NaN rows from a 0/0 division.

        Raises:
            ValueError: if the model has no trees (or no weights).
        """
        flat_alpha = np.asarray(self.alpha, dtype=np.float64).reshape(-1)
        # zip truncates to the shorter of trees/weights
        pairs = list(zip(self.trees, flat_alpha))
        if not pairs:
            raise ValueError(f"{self!r} has no weighted trees to predict with")
        weights = np.array([weight for _, weight in pairs])
        if not np.any(weights):
            weights = np.ones_like(weights)
        # (n_trees, n_samples, n_classes), each tree scaled by its weight
        weighted = np.stack([tree.predict_proba(X) for tree, _ in pairs]) * weights[:, None, None]
        return weighted.sum(axis=0) / weighted.sum(axis=(0, 2))[:, None]

    def predict(self, X) -> np.ndarray:
        """Return, per sample, the index of the class with the highest ensemble probability."""
        return self.predict_proba(X).argmax(axis=1)

    def __repr__(self) -> str:
        return f"Model#{self.i}(trees={len(self.trees)})"
class Worker(FastAPI):
    """A FastAPI application that is also one node of the federation.

    The same class is used for the coordinator and for the data workers. The
    methods below implement the numbered steps of the BOFRF algorithm run on
    a data worker, plus ``aggregate`` (step 6) used by the coordinator.
    """

    def __init__(self, prefix: str, lifespan, **extra: Any):
        super().__init__(lifespan=lifespan, **extra)
        # self-assigned identifier: "<prefix>-<first 8 chars of a uuid4>"
        self.id: str = f"{prefix}-{str(uuid.uuid4())[:8]}"
        # unique numeric index inside the federation; -1 until assigned
        self.i: int = -1
        # barrier counters: messages of each kind received from peers so far
        self.trees_received: int = 0
        self.matrix_received: int = 0
        self.models_received: int = 0
        # info to contact the coordinator node (assigned later, not here)
        self.coordinator: WorkerId
        # info to contact every worker in the network, keyed by worker id
        self.workers: dict[str, WorkerId] = dict()
        # local data, populated by load()
        self.X: np.ndarray
        self.Y: np.ndarray
        # local random forest with N_TREES trees
        self.R: RandomForestClassifier = RandomForestClassifier(N_TREES)
        # confusion matrices indexed [site, tree]; row s is meant to hold this
        # node's trees evaluated on site s's data (filled through update())
        self.C: np.ndarray = np.zeros((N_SITES, N_TREES, 2, 2))

    def load(self, path: Path) -> None:
        """Load the local dataset from the CSV at ``path``.

        Column ``y`` is the target; every other column is a feature.
        """
        df = pd.read_csv(path)
        self.X: np.ndarray = df.drop("y", axis=1).to_numpy()
        self.Y: np.ndarray = df["y"].to_numpy()

    def _confusion(self, tree: DecisionTreeClassifier) -> np.ndarray:
        """Return the 2x2 confusion matrix of ``tree`` on the local data.

        ``labels=[0, 1]`` forces a 2x2 result even when only one class occurs
        in the local targets and predictions. Otherwise sklearn returns a 1x1
        matrix, which NumPy would silently broadcast into all four cells of
        the 2x2 slot it is assigned to. Assumes ``y`` holds binary 0/1
        labels -- TODO confirm for new datasets.
        """
        return confusion_matrix(self.Y, tree.predict(self.X), labels=[0, 1])

    def train(self) -> None:
        """Step 1: train a Random Forest on the local data and score its trees locally.

        Raises:
            RuntimeError: if called before this worker received its index ``i``.
        """
        if self.i < 0:
            raise RuntimeError(f"{self!r} has no federation index yet; wait for the start command")
        self.R.fit(self.X, self.Y)
        for k, tree in enumerate(self.share_trees()):
            # own trees on own data belong in this site's row, not row 0
            self.C[self.i, k] = self._confusion(tree)

    def share_trees(self) -> list[DecisionTreeClassifier]:
        """Step 2: return the fitted trees of the local forest, to be shared with every other site."""
        return self.R.estimators_

    def evaluate(self, trees: list[DecisionTreeClassifier]) -> np.ndarray:
        """Step 3: evaluate ``trees`` (from any site) on the local data.

        Returns an array of shape (N_TREES, 2, 2): one confusion matrix per tree.
        """
        C = np.zeros((N_TREES, 2, 2))
        for k, tree in enumerate(trees):
            C[k] = self._confusion(tree)
        return C

    def update(self, i: int, cm: np.ndarray) -> None:
        """Step 3: store ``cm`` (shape (N_TREES, 2, 2)) as row ``i`` of ``self.C``."""
        self.C[i] = cm

    def merge_cm(self) -> np.ndarray:
        """Step 4: merge the confusion matrices of all sites and compute tree weights.

        For each tree, the matrices of every site are summed. If the MCC of
        the sum exceeds TAU, it becomes the tree's weight; otherwise the
        weight stays 0. Returns an array of shape (N_TREES, 1).
        """
        alpha: np.ndarray = np.zeros((N_TREES, 1))
        c_q = self.C.sum(axis=0)
        for k in range(N_TREES):
            w = mcc(c_q[k])
            if w > TAU:
                alpha[k] = w
        return alpha

    def get_local_model(self) -> Model:
        """Step 5: build the local ensemble from the own trees and their merged weights."""
        return Model(
            self.id,
            self.share_trees(),
            self.merge_cm(),
        )

    def aggregate(self, local_models: list[Model]) -> Model:
        """Step 6: concatenate the local models into the global federated ensemble.

        Trees keep the weights computed by their owner. The weight array is
        zero-initialised and sized by the trees actually received. As a
        result, no uninitialised memory ends up in the model when fewer than
        N_SITES models (or fewer than N_TREES trees per model) are passed in.
        """
        trees: list[DecisionTreeClassifier] = [t for m in local_models for t in m.trees]
        alpha: np.ndarray = np.zeros((len(trees), 1))
        row = 0
        for local_model in local_models:
            for k in range(len(local_model.trees)):
                alpha[row] = local_model.alpha[k]
                row += 1
        return Model("global", trees, alpha)

    def __repr__(self) -> str:
        return f"Worker#{self.id}({self.i})"

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Worker):
            return self.id == other.id
        return False

    def __hash__(self) -> int:
        # __eq__ compares ids. Without a matching __hash__, Python sets
        # __hash__ to None and instances become unhashable.
        return hash(self.id)

    def identify(self) -> WorkerId:
        """Return the identity this node advertises; its URL is built from HOSTNAME."""
        return WorkerId(i=self.i, id=self.id, url=f"http://{HOSTNAME}")
async def startup(worker: Worker):
    """Run the node's startup procedure.

    On the coordinator, nothing happens. A data node POSTs its identity to
    ``/ready/<id>`` on the coordinator, retrying once per second until it
    gets a 2xx reply. It then stores the coordinator identity from the reply
    in ``worker.coordinator``.
    """
    LOGGER.info("startup procedure initiated")
    if MODE != "data":
        LOGGER.info("node is a coordinator, startup completed")
        return
    wid: WorkerId = worker.identify()
    LOGGER.info(f"node is a worker {wid.id} proceed with registration sequence")
    # ``async with`` closes the client even if the retry loop is cancelled
    # (e.g. the server is stopped while the coordinator is still down).
    async with httpx.AsyncClient() as client:
        while True:
            LOGGER.info(f"contacting coordinator at {HOSTNAME_COORDINATOR}...")
            try:
                r: httpx.Response = await client.post(
                    f"http://{HOSTNAME_COORDINATOR}/ready/{worker.id}",
                    content=wid.model_dump_json(),
                )
                LOGGER.info(f"ready message: {r.status_code}")
                r.raise_for_status()
                worker.coordinator = WorkerId(**json.loads(r.content))
            except (httpx.HTTPError, ValueError, TypeError) as e:
                # unreachable coordinator, non-2xx status or malformed reply:
                # wait and retry. Any other exception is a bug and propagates.
                LOGGER.warning(f"coordinator node not available ({e}) waiting 1 second")
                await asyncio.sleep(1)
                continue
            LOGGER.info(f"coordinator info: {worker.coordinator}")
            break
    LOGGER.info("startup procedure completed")
async def shutdown(worker: Worker):
    """Shutdown hook run by ``lifespan`` when the server stops; it only logs."""
    LOGGER.info(f"commencing shutdown procedure for node {worker}")
@asynccontextmanager
async def lifespan(worker: Worker):
    """FastAPI lifespan handler: run ``startup`` before serving and ``shutdown`` after.

    ``shutdown`` sits in a ``finally`` so it also runs when the lifespan
    context is exited with an exception (for example on cancellation).
    """
    await startup(worker)
    try:
        yield
    finally:
        await shutdown(worker)
# FastAPI object definition: this node's application. Its id is prefixed
# with MODE ("data" or "coordinator").
worker: Worker = Worker(prefix=MODE, lifespan=lifespan)
class Command(str, Enum):
    """All commands are steps in the algorithm.

    The value is the first path segment of ``POST /{command}/{id}``.
    """
    ready = "ready" # Step 0: a data worker registers with the coordinator
    start = "start" # Step 1: coordinator tells every worker to train its local Random Forest
    trees = "trees" # Step 2: a worker shares its decision trees with every other site
    matrix = "matrix" # Step 3: a worker sends back the confusion matrices of a peer's trees
    done = "done" # Steps 5-6: a worker sends its local model; the coordinator builds the global model
def _known_sender(sender_id: str) -> WorkerId:
    """Return the registered WorkerId of ``sender_id``, or answer 404.

    Peer messages (trees, matrix, done) can only be handled once the sender
    is known to this node. Without this check, an unknown id raises KeyError
    and the client receives a generic HTTP 500.
    """
    try:
        return worker.workers[sender_id]
    except KeyError:
        raise HTTPException(status_code=404, detail=f"unknown worker {sender_id}") from None


@worker.post("/{command}/{id}")
@worker.get("/")
async def exec(
    request: Request,
    background: BackgroundTasks,
    id: str = "",
    command: Command | None = None,
) -> WorkerId:
    """This is the main and only endpoint to interact with the workers.

    It imitates a chat between all the components of the federated network.
    ``POST /{command}/{id}`` delivers one message of the algorithm, sent by
    the node whose id is ``id``. ``GET /`` (no command, no id) just returns
    this node's identity. Long-running work is scheduled as a background
    task so the sender gets an answer right away. The reply is always this
    node's WorkerId.

    Raises:
        HTTPException: 404 when a trees/matrix/done message comes from an
            unknown sender.
    """
    LOGGER.info(f"command: {command} id: {id}")
    content: bytes = await request.body()
    if command is Command.ready:
        # a new worker is registering to the coordinator
        LOGGER.info(f"new worker {id} is ready")
        is_new: bool = id not in worker.workers
        worker.workers[id] = WorkerId(**json.loads(content))
        # Only a *new* registration may complete the barrier. A worker that
        # retries /ready must not trigger a second round of start commands.
        if is_new and len(worker.workers) == N_SITES:
            background.add_task(task_start_workers, worker=worker)
    elif command is Command.start:
        # received start command: learn the federation and our own index
        data = WorkerList(**json.loads(content))
        worker.coordinator = data.coordinator
        for w in data.workers:
            worker.workers[w.id] = w
            if w.id == worker.id:
                worker.i = w.i
        LOGGER.info(f"start training with coordinator={worker.coordinator.id} and {len(worker.workers)} worker(s)")
        background.add_task(task_train_local_model, worker=worker)
    elif command is Command.trees:
        # received new (pickled) trees from another worker; stored verbatim
        w: WorkerId = _known_sender(id)
        LOGGER.info(f"received trees from {w.id} ({w.i})")
        await adump(WORKING_DIR / f"{w.id}.{w.i}.trees", content)
        worker.trees_received += 1
        LOGGER.info(f"total trees: {worker.trees_received}/{N_SITES-1}")
        if worker.trees_received == N_SITES - 1:
            LOGGER.info(f"received {worker.trees_received} trees, scheduling task")
            background.add_task(task_local_evaluation, worker=worker)
    elif command is Command.matrix:
        # received the confusion matrices another worker computed for our trees
        w: WorkerId = _known_sender(id)
        LOGGER.info(f"received matrix from {w.id} ({w.i})")
        # NOTE(security): pickle.loads on a network payload can execute
        # arbitrary code; acceptable only on a trusted network.
        cm: np.ndarray = pickle.loads(content)
        worker.update(w.i, cm)
        await adump(WORKING_DIR / f"{w.id}.{w.i}.cm", cm)
        worker.matrix_received += 1
        LOGGER.info(f"total confusion matrix: {worker.matrix_received}/{N_SITES-1}")
        if worker.matrix_received == N_SITES - 1:
            LOGGER.info(f"received {worker.matrix_received} matrices, scheduling task")
            background.add_task(task_build_local_model, worker=worker)
    elif command is Command.done:
        # received a (pickled) local model from a worker
        w: WorkerId = _known_sender(id)
        LOGGER.info(f"received local model from {w.id} ({w.i})")
        await adump(WORKING_DIR / f"{w.id}.{w.i}.model", content)
        worker.models_received += 1
        if worker.models_received == N_SITES:
            LOGGER.info(f"received {worker.models_received} local models, scheduling aggregation")
            background.add_task(task_aggregate, worker=worker)
    return worker.identify()
def task_start_workers(worker: Worker, attempts: int = 30, delay: float = 1.0):
    """Coordinator background task: number the workers and send each one ``start``.

    Workers are numbered in registration order. Every worker receives the
    full WorkerList: the coordinator plus all workers with their index ``i``.
    The last worker to register may not be accepting connections yet,
    presumably because its lifespan startup, which performed the
    registration, is still finishing. Transport errors are therefore retried
    up to ``attempts`` times, ``delay`` seconds apart, before giving up.

    Raises:
        httpx.TransportError: a worker stayed unreachable after all attempts.
        httpx.HTTPStatusError: a worker answered with a non-2xx status.
    """
    LOGGER.info("all workers ready, sending start command")
    attempts = max(1, attempts)
    # assign numeric id
    worker_list: list[WorkerId] = list()
    for i, w in enumerate(worker.workers.values()):
        w.i = i
        worker_list.append(w)
    wl: WorkerList = WorkerList(
        coordinator=worker.identify(),
        workers=worker_list,
    )
    payload: str = wl.model_dump_json()
    for w in worker_list:
        for attempt in range(1, attempts + 1):
            try:
                r: httpx.Response = httpx.post(
                    f"{w.url}/start/{worker.id}",
                    content=payload,
                )
                break
            except httpx.TransportError as e:
                if attempt == attempts:
                    raise
                LOGGER.warning(f"worker {w.id} not reachable ({e}), retrying in {delay} second(s)")
                time.sleep(delay)
        r.raise_for_status()
def task_train_local_model(worker: Worker):
    """Data-worker background task: load the local CSV, fit the forest, then
    persist the trees and broadcast them to every other worker."""
    LOGGER.info("loading data")
    worker.load(Path(PATH_DATA))
    LOGGER.info("training...")
    worker.train()

    trees: list[DecisionTreeClassifier] = worker.share_trees()
    # our own copy on disk; task_local_evaluation waits for this file
    dump(WORKING_DIR / f"{worker.id}.{worker.i}.trees", trees)
    LOGGER.info("saved local trees to disk")

    peers = (w for w in worker.workers.values() if w.id != worker.id)
    for peer in peers:
        LOGGER.info(f"sending trees to {peer.id}")
        response: httpx.Response = httpx.post(
            f"{peer.url}/trees/{worker.id}",
            content=pickle.dumps(trees),
        )
        response.raise_for_status()
    LOGGER.info("training task completed, local trees sent to other workers")
def task_local_evaluation(worker: Worker):
    """Data-worker background task: score every site's trees on the local data.

    Scheduled once trees from all N_SITES - 1 peers have arrived. It first
    waits until local training has produced this node's own trees file. Then,
    for each site, it computes per-tree confusion matrices on this node's
    data:

    * own trees: stored in ``worker.C`` at this node's row and saved as
      ``<id>.<i>.cm``, the file ``task_build_local_model`` waits for;
    * a peer's trees: sent back to that peer with ``POST /matrix`` and saved
      locally as ``<peer id>.<peer i>.outgoing.cm``. They must not go into
      this node's ``C``, whose rows describe this node's own trees, nor into
      ``<peer id>.<peer i>.cm``, which the ``/matrix`` handler writes with the
      matrix received from that peer.

    Raises:
        httpx.HTTPError: a peer could not be reached or rejected the matrix.
    """
    while not os.path.exists(WORKING_DIR / f"{worker.id}.{worker.i}.trees"):
        LOGGER.warning(f"waiting for local training of {worker.id}.{worker.i}.trees (1 second)")
        time.sleep(1.0)
    LOGGER.info("all trees available, starting confusion matrices creation")
    for w in worker.workers.values():
        LOGGER.info(f"evaluating matrix for {w.id} ({w.i})")
        if w.id == worker.id:
            # training is finished (its trees file exists), so use the
            # in-memory trees rather than a file that may still be written
            trees: list[DecisionTreeClassifier] = worker.share_trees()
        else:
            trees = load(WORKING_DIR / f"{w.id}.{w.i}.trees")
        cm = worker.evaluate(trees)
        if w.id == worker.id:
            worker.update(w.i, cm)
            dump(WORKING_DIR / f"{w.id}.{w.i}.cm", cm)
            LOGGER.info(f"local matrix {w.id} ({w.i}) available")
        else:
            dump(WORKING_DIR / f"{w.id}.{w.i}.outgoing.cm", cm)
            LOGGER.info(f"sending local matrix to {w.id} ({w.i})")
            # send to the owner of the trees
            r: httpx.Response = httpx.post(
                f"{w.url}/matrix/{worker.id}",
                content=pickle.dumps(cm),
            )
            r.raise_for_status()
    LOGGER.info("matrices sent to all workers")
def task_build_local_model(worker: Worker):
    """Data-worker background task: build the local weighted model and send it to the coordinator.

    Scheduled once confusion matrices from all N_SITES - 1 peers have
    arrived. It also waits for this node's own matrix file (``<id>.<i>.cm``)
    before merging. The pickled model is saved as ``<id>.<i>.model`` and
    POSTed to ``/done`` on the coordinator.

    Raises:
        httpx.HTTPError: the coordinator could not be reached or rejected the model.
    """
    while not os.path.exists(WORKING_DIR / f"{worker.id}.{worker.i}.cm"):
        LOGGER.warning("local matrix not available, waiting 1 second")
        time.sleep(1.0)
    LOGGER.info("creating local model")
    local_model: Model = worker.get_local_model()
    content = pickle.dumps(local_model)
    dump(WORKING_DIR / f"{worker.id}.{worker.i}.model", content)
    LOGGER.info("sending local model to coordinator")
    r: httpx.Response = httpx.post(
        f"{worker.coordinator.url}/done/{worker.id}",
        content=content,
    )
    # surface a rejected upload instead of reporting "local work done"
    r.raise_for_status()
    LOGGER.info("local work done")
def task_aggregate(worker: Worker):
    """Coordinator background task: merge every worker's local model into ``global.model``."""
    LOGGER.info("start global model aggregation")
    model_paths = (WORKING_DIR / f"{w.id}.{w.i}.model" for w in worker.workers.values())
    local_models: list[Model] = [load(model_path) for model_path in model_paths]
    global_model = worker.aggregate(local_models)
    dump(WORKING_DIR / "global.model", global_model)
    LOGGER.info("global model aggregation completed")
# just to maintain traditions :)
# ``app`` is the conventional ASGI entry-point name; the compose file starts
# every node with ``uvicorn script:app``.
app: Worker = worker
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment