|
from enum import Enum |
|
from typing import Any |
|
|
|
from contextlib import asynccontextmanager |
|
from dataclasses import dataclass |
|
from fastapi import FastAPI, BackgroundTasks, HTTPException, Request |
|
from pathlib import Path |
|
from pydantic import BaseModel |
|
from sklearn.ensemble import RandomForestClassifier |
|
from sklearn.metrics import confusion_matrix |
|
from sklearn.tree import DecisionTreeClassifier |
|
|
|
import aiofiles |
|
import asyncio |
|
import httpx |
|
import json |
|
import logging |
|
import numpy as np |
|
import os |
|
import pandas as pd |
|
import pickle |
|
import time |
|
import uuid |
|
|
|
|
|
logging.basicConfig( |
|
level=logging.INFO, |
|
format="%(asctime)s,%(msecs)03d %(levelname)8s %(name)16s:%(lineno)-4s %(message)s", |
|
datefmt="%Y-%m-%d %H:%M:%S", |
|
) |
|
|
|
LOGGER = logging.getLogger(__name__) |
|
|
|
# parameters |
|
MODE: str = os.environ.get("MODE", "data") # or "coordinator" |
|
|
|
HOSTNAME: str = os.environ.get("HOSTNAME", "") |
|
HOSTNAME_COORDINATOR: str = os.environ.get("HOSTNAME_COORDINATOR", "") |
|
|
|
N_SITES: int = int(os.environ.get("N_NODES", "3"))
N_TREES: int = int(os.environ.get("N_TREES", "5"))
TAU: float = float(os.environ.get("TAU", "0.2"))
|
|
|
PATH_DATA: Path = Path(os.environ.get("PATH_DATA", "/data.csv")) |
|
WORKING_DIR: Path = Path(os.environ.get("WORKING_DIR", "/workdir")) |
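# Illustrative example (assumed values, not prescribed by this module): a run with
# one coordinator and three data nodes could be configured with environment
# variables along these lines, where hostnames and paths are placeholders:
#
#   coordinator : MODE=coordinator HOSTNAME=coordinator:80
#   data node 1 : MODE=data HOSTNAME=worker-1:80 HOSTNAME_COORDINATOR=coordinator:80
#                 N_NODES=3 N_TREES=5 TAU=0.2 PATH_DATA=/data.csv WORKING_DIR=/workdir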
|
|
|
|
|
# functions |
|
def mcc(cm: np.ndarray) -> float:
    """Matthews correlation coefficient computed from a 2x2 confusion matrix."""
    tn, fp, fn, tp = cm.ravel()
    return ((tp * tn) - (fp * fn)) / np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
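# Quick sanity check (illustrative, not part of the pipeline): a tree that classifies
# every sample correctly has a diagonal confusion matrix, so
# mcc(np.array([[5, 0], [0, 5]])) evaluates to 1.0; a tree that always predicts the
# same class produces a zero denominator (the result is nan), which the TAU filter
# in Worker.merge_cm treats as weight 0.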
|
|
|
|
|
def dump(path: Path, o: Any | bytes) -> None: |
|
with open(path, "wb") as f: |
|
content: bytes = o if isinstance(o, bytes) else pickle.dumps(o) |
|
f.write(content) |
|
|
|
|
|
def load(path: Path) -> Any: |
|
with open(path, "rb") as f: |
|
return pickle.load(f) |
|
|
|
|
|
async def adump(path: Path, o: Any | bytes) -> None: |
|
async with aiofiles.open(path, "wb") as f: |
|
content: bytes = o if isinstance(o, bytes) else pickle.dumps(o) |
|
await f.write(content) |
|
|
|
|
|
async def aload(path: Path) -> Any:
|
async with aiofiles.open(path, "rb") as f: |
|
content: bytes = await f.read() |
|
return pickle.loads(content) |
|
|
|
|
|
# exchange data definition |
|
class WorkerId(BaseModel): |
|
i: int |
|
id: str |
|
url: str |
|
|
|
|
|
class WorkerList(BaseModel): |
|
coordinator: WorkerId |
|
workers: list[WorkerId] |
|
|
|
|
|
# execution classes |
|
|
|
|
|
@dataclass |
|
class Model: |
|
i: str |
|
trees: list[DecisionTreeClassifier] |
|
alpha: np.ndarray |
|
|
|
def predict_proba(self, X) -> np.ndarray: |
|
probs = list() |
|
for tree, alpha in zip(self.trees, self.alpha): |
|
p = tree.predict_proba(X) |
|
probs.append(p * alpha) |
|
|
|
a_probs = np.array(probs) |
|
return a_probs.sum(axis=0) / a_probs.sum(axis=(0, 2)).reshape((X.shape[0], 1)) |
|
|
|
def predict(self, X) -> np.ndarray: |
|
return self.predict_proba(X).argmax(axis=1) |
|
|
|
def __repr__(self) -> str: |
|
return f"Model#{self.i}(trees={len(self.trees)})" |
|
|
|
|
|
class Worker(FastAPI): |
|
def __init__(self, prefix: str, lifespan, **extra: Any): |
|
super().__init__(lifespan=lifespan, **extra) |
|
|
|
        # self-assigned identification number
|
self.id: str = f"{prefix}-{str(uuid.uuid4())[:8]}" |
|
# unique numeric id inside the federation |
|
self.i: int = -1 |
|
|
|
# barrier counters |
|
self.trees_received: int = 0 |
|
self.matrix_received: int = 0 |
|
self.models_received: int = 0 |
|
|
|
        # info needed to contact the aggregator (coordinator) node
|
self.coordinator: WorkerId |
|
        # info needed to contact all the other workers in the network
|
self.workers: dict[str, WorkerId] = dict() |
|
|
|
# data loaded |
|
self.X: np.ndarray |
|
self.Y: np.ndarray |
|
|
|
        # model definition
|
        self.R: RandomForestClassifier = RandomForestClassifier(n_estimators=N_TREES)
|
self.C: np.ndarray = np.zeros((N_SITES, N_TREES, 2, 2)) # site, tree, matrix |
|
|
|
def load(self, path: Path) -> None: |
|
"""Load data from disk from the given path.""" |
|
df = pd.read_csv(path) |
|
|
|
self.X: np.ndarray = df.drop("y", axis=1).to_numpy() |
|
self.Y: np.ndarray = df["y"].to_numpy() |
|
|
|
def train(self) -> None: |
|
"""Step 1: Train a Random Forest model at each site.""" |
|
self.R.fit(self.X, self.Y) |
|
|
|
for k, tree in enumerate(self.share_trees()): |
|
y_pred = tree.predict(self.X) |
|
            self.C[self.i, k] = confusion_matrix(self.Y, y_pred)  # store this site's own matrices under its own site index
|
|
|
def share_trees(self) -> list[DecisionTreeClassifier]: |
|
"""Step 2: Share decision trees with every other site.""" |
|
return self.R.estimators_ |
|
|
|
def evaluate(self, trees: list[DecisionTreeClassifier]) -> np.ndarray: |
|
"""Step 3: Run each decision tree retrieved from other sites.""" |
|
C = np.zeros((N_TREES, 2, 2)) |
|
|
|
for k, tree in enumerate(trees): |
|
y_pred = tree.predict(self.X) |
|
C[k] = confusion_matrix(self.Y, y_pred) |
|
|
|
return C |
|
|
|
def update(self, i: int, cm: np.ndarray) -> None: |
|
"""Step 3: Send confusion matrices.""" |
|
self.C[i] = cm |
|
|
|
def merge_cm(self) -> np.ndarray: |
|
"""Step 4: Merge confusion matrices and calculate weight.""" |
|
alpha: np.ndarray = np.zeros((N_TREES, 1)) |
|
|
|
c_q = self.C.sum(axis=0) |
|
for k in range(N_TREES): |
|
C_in = c_q[k] |
|
w = mcc(C_in) |
|
|
|
if w > TAU: |
|
alpha[k] = w |
|
|
|
return alpha |
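    # Illustrative example (assumed numbers): with TAU = 0.2, a tree whose summed
    # confusion matrix gives an MCC of 0.55 keeps weight 0.55, while a tree with an
    # MCC of 0.1 (or nan from a degenerate matrix) keeps weight 0 and therefore
    # contributes nothing to the weighted vote in Model.predict_proba.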
|
|
|
def get_local_model(self) -> Model: |
|
"""Step 5: Build the local ensemble model.""" |
|
return Model( |
|
self.id, |
|
self.share_trees(), |
|
self.merge_cm(), |
|
) |
|
|
|
def aggregate(self, local_models: list[Model]) -> Model: |
|
"""Step 6: Build the global federated ensemble model""" |
|
|
|
alpha: np.ndarray = np.ndarray((N_TREES * N_SITES, 1)) |
|
trees: list[DecisionTreeClassifier] = list() |
|
|
|
for j, local_model in enumerate(local_models): |
|
for i, tree in enumerate(local_model.trees): |
|
trees.append(tree) |
|
alpha[j * N_TREES + i] = local_model.alpha[i] |
|
|
|
return Model("global", trees, alpha) |
|
|
|
def __repr__(self) -> str: |
|
return f"Worker#{self.id}({self.i})" |
|
|
|
def __eq__(self, other: object) -> bool: |
|
if isinstance(other, Worker): |
|
return self.id == other.id |
|
return False |
|
|
|
def identify(self) -> WorkerId: |
|
return WorkerId(i=self.i, id=self.id, url=f"http://{HOSTNAME}") |
|
|
|
|
|
async def startup(worker: Worker): |
|
"""Startup procedure. If it is the coordinator node, then nothing happens. |
|
When it is a data node, it will try to register itself.""" |
|
LOGGER.info(f"startup procedure initiated") |
|
|
|
if MODE != "data": |
|
LOGGER.info("node is a coordinator, startup completed") |
|
return |
|
|
|
wid: WorkerId = worker.identify() |
|
|
|
LOGGER.info(f"node is a worker {wid.id} proceed with registration sequence") |
|
|
|
flag: bool = True |
|
client: httpx.AsyncClient = httpx.AsyncClient() |
|
|
|
while flag: |
|
LOGGER.info(f"contacting coordinator at {HOSTNAME_COORDINATOR}...") |
|
|
|
try: |
|
r: httpx.Response = await client.post( |
|
f"http://{HOSTNAME_COORDINATOR}/ready/{worker.id}", |
|
content=wid.model_dump_json(), |
|
) |
|
|
|
LOGGER.info(f"ready message: {r.status_code}") |
|
|
|
r.raise_for_status() |
|
|
|
worker.coordinator = WorkerId(**json.loads(r.content)) |
|
LOGGER.info(f"coordinator info: {worker.coordinator}") |
|
flag = False |
|
|
|
except Exception as e: |
|
LOGGER.warning(f"coordinator node not available ({e}) waiting 1 second") |
|
await asyncio.sleep(1) |
|
|
|
await client.aclose() |
|
|
|
LOGGER.info("startup procedure completed") |
|
|
|
|
|
async def shutdown(worker: Worker): |
|
LOGGER.info(f"comencing shutdown procedure for node {worker}") |
|
|
|
|
|
@asynccontextmanager |
|
async def lifespan(worker: Worker): |
|
"""This is just to control startup and shutdown procedures for FastAPI.""" |
|
await startup(worker) |
|
yield |
|
await shutdown(worker) |
|
|
|
|
|
# FastAPI object definition
|
worker: Worker = Worker(MODE, lifespan=lifespan) |
|
|
|
|
|
class Command(str, Enum): |
|
"""All commands are steps in the algorithm""" |
|
|
|
ready = "ready" # Step0: Register the worker |
|
start = "start" # Step 1: Train a Random Forest model at each site |
|
trees = "trees" # Step 2: Share decision trees with every other site. |
|
matrix = "matrix" # Step 3: Send confusion matrices. |
|
done = "done" # Step 6: Build the global federated ensemble model |
|
|
|
|
|
@worker.post("/{command}/{id}") |
|
@worker.get("/") |
|
async def exec( |
|
request: Request, |
|
background: BackgroundTasks, |
|
id: str, |
|
command: Command | None = None, |
|
) -> WorkerId: |
|
"""This is the main and only endpoint to interact with the workers. This |
|
endpoint imitates a chat between all the component of the federated network.""" |
|
|
|
LOGGER.info(f"command: {command} id: {id}") |
|
|
|
content: bytes = await request.body() |
|
|
|
if command is Command.ready: |
|
# a new worker is registering to the coordinator |
|
LOGGER.info(f"new worker {id} is ready") |
|
|
|
worker.workers[id] = WorkerId(**json.loads(content)) |
|
|
|
if len(worker.workers) == N_SITES: |
|
background.add_task(task_start_workers, worker=worker) |
|
|
|
elif command is Command.start: |
|
# received start command |
|
data = WorkerList(**json.loads(content)) |
|
|
|
worker.coordinator = data.coordinator |
|
for w in data.workers: |
|
worker.workers[w.id] = w |
|
if w.id == worker.id: |
|
worker.i = w.i |
|
|
|
LOGGER.info(f"start training with coordinator={worker.coordinator.id} and {len(worker.workers)} worker(s)") |
|
|
|
background.add_task(task_train_local_model, worker=worker) |
|
|
|
elif command is Command.trees: |
|
# received new trees from another worker |
|
w: WorkerId = worker.workers[id] |
|
|
|
LOGGER.info(f"received trees from {w.id} ({w.i})") |
|
|
|
await adump(WORKING_DIR / f"{w.id}.{w.i}.trees", content) |
|
worker.trees_received += 1 |
|
|
|
LOGGER.info(f"total trees: {worker.trees_received}/{N_SITES-1}") |
|
|
|
if worker.trees_received == N_SITES - 1: |
|
LOGGER.info(f"received {worker.trees_received} trees, scheduling task") |
|
background.add_task(task_local_evaluation, worker=worker) |
|
|
|
elif command is Command.matrix: |
|
# received a new confusion matrix from another worker |
|
w: WorkerId = worker.workers[id] |
|
|
|
LOGGER.info(f"received matrix from {w.id} ({w.i})") |
|
|
|
cm: np.ndarray = pickle.loads(content) |
|
|
|
worker.update(w.i, cm) |
|
|
|
await adump(WORKING_DIR / f"{w.id}.{w.i}.cm", cm) |
|
worker.matrix_received += 1 |
|
|
|
LOGGER.info(f"total confusion matrix: {worker.matrix_received}/{N_SITES-1}") |
|
|
|
if worker.matrix_received == N_SITES - 1: |
|
LOGGER.info(f"received {worker.matrix_received} matrices, scheduling task") |
|
background.add_task(task_build_local_model, worker=worker) |
|
|
|
elif command is Command.done: |
|
# received a local model from a worker |
|
w: WorkerId = worker.workers[id] |
|
|
|
LOGGER.info(f"received local model from {w.id} ({w.i})") |
|
|
|
await adump(WORKING_DIR / f"{w.id}.{w.i}.model", content) |
|
worker.models_received += 1 |
|
|
|
if worker.models_received == N_SITES: |
|
LOGGER.info(f"received {worker.models_received} local models, scheduling aggregation") |
|
background.add_task(task_aggregate, worker=worker) |
|
|
|
return worker.identify() |
|
|
|
|
|
def task_start_workers(worker: Worker): |
|
"""Background task of the coordinator to launch the work on all the workers with data.""" |
|
|
|
LOGGER.info("all workers ready, sending start command") |
|
|
|
# assign numeric id |
|
worker_list: list[WorkerId] = list() |
|
for i, w in enumerate(worker.workers.values()): |
|
w.i = i |
|
worker_list.append(w) |
|
|
|
wl: WorkerList = WorkerList( |
|
coordinator=worker.identify(), |
|
workers=worker_list, |
|
) |
|
|
|
for w in worker_list: |
|
r: httpx.Response = httpx.post( |
|
f"{w.url}/start/{worker.id}", |
|
content=wl.model_dump_json(), |
|
) |
|
r.raise_for_status() |
|
|
|
|
|
def task_train_local_model(worker: Worker): |
|
"""Background task of a worker that load local data and train the local model.""" |
|
|
|
LOGGER.info("loading data") |
|
worker.load(Path(PATH_DATA)) |
|
|
|
LOGGER.info("training...") |
|
worker.train() |
|
|
|
trees: list[DecisionTreeClassifier] = worker.share_trees() |
|
dump(WORKING_DIR / f"{worker.id}.{worker.i}.trees", trees) |
|
LOGGER.info("saved local trees to disk") |
|
|
|
for _, w in worker.workers.items(): |
|
# distribute the local trees to other workers |
|
if w.id == worker.id: |
|
continue |
|
|
|
LOGGER.info(f"sending trees to {w.id}") |
|
r: httpx.Response = httpx.post( |
|
f"{w.url}/trees/{worker.id}", |
|
content=pickle.dumps(trees), |
|
) |
|
|
|
r.raise_for_status() |
|
|
|
LOGGER.info("training task completed, local trees sent to other workers") |
|
|
|
|
|
def task_local_evaluation(worker: Worker): |
|
"""Background task for a worker that will evaluate other decison trees on local data.""" |
|
|
|
while not os.path.exists(WORKING_DIR / f"{worker.id}.{worker.i}.trees"): |
|
LOGGER.warn(f"waiting for local training of {worker.id}.{worker.i}.trees (1 second)") |
|
time.sleep(1.0) |
|
|
|
LOGGER.warn("all trees available, starting confusion matirces creation") |
|
|
|
for _, w in worker.workers.items(): |
|
# evaluates the decision trees on local data, then share the obtained confusion matrix |
|
|
|
LOGGER.info(f"evaluating matrix for {w.id} ({w.i})") |
|
trees: list[DecisionTreeClassifier] = load(WORKING_DIR / f"{w.id}.{w.i}.trees") |
|
|
|
cm = worker.evaluate(trees) |
|
worker.update(w.i, cm) |
|
|
|
dump(WORKING_DIR / f"{w.id}.{w.i}.cm", cm) |
|
|
|
if w.id == worker.id: |
|
LOGGER.info(f"local matrix {w.id} ({w.i}) available") |
|
else: |
|
LOGGER.info(f"sending local matrix to {w.id} ({w.i})") |
|
# send to other workers |
|
httpx.post( |
|
f"{w.url}/matrix/{worker.id}", |
|
content=pickle.dumps(cm), |
|
) |
|
|
|
LOGGER.info("matrices sent to all workers") |
|
|
|
|
|
def task_build_local_model(worker: Worker): |
|
"""Background task of the worker used to build a local model from.""" |
|
|
|
while not os.path.exists(WORKING_DIR / f"{worker.id}.{worker.i}.cm"): |
|
LOGGER.warn("local matrix not available, waiting 1 second") |
|
time.sleep(1.0) |
|
|
|
LOGGER.info("creating local model") |
|
|
|
local_model: Model = worker.get_local_model() |
|
|
|
content = pickle.dumps(local_model) |
|
|
|
dump(WORKING_DIR / f"{worker.id}.{worker.i}.model", content) |
|
|
|
LOGGER.info("sending local model to coordinator") |
|
|
|
httpx.post( |
|
f"{worker.coordinator.url}/done/{worker.id}", |
|
content=content, |
|
) |
|
|
|
LOGGER.info("local work done") |
|
|
|
|
|
def task_aggregate(worker: Worker): |
|
"""Background task of the coordinator used to build the global model.""" |
|
|
|
LOGGER.info("start global model aggregation") |
|
|
|
local_models: list[Model] = list() |
|
for _, w in worker.workers.items(): |
|
local_model: Model = load(WORKING_DIR / f"{w.id}.{w.i}.model") |
|
local_models.append(local_model) |
|
|
|
global_model = worker.aggregate(local_models) |
|
|
|
dump(WORKING_DIR / "global.model", global_model) |
|
|
|
LOGGER.info("global model aggregation completed") |
|
|
|
|
|
# just to maintain traditions :) |
|
app = worker |