Skip to content

Instantly share code, notes, and snippets.

View jvmncs's full-sized avatar
📖

jvmncs

📖
View GitHub Profile
# Train a Keras MNIST classifier remotely on a PySyft worker ("alice").
# NOTE(review): `tf` and the worker `alice` are not defined in this snippet;
# presumably `import tensorflow as tf`, a PySyft import/hook, and the worker
# construction happen earlier — confirm.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Divide by 255 so the 8-bit pixel values land in [0, 1].
x_train, x_test = x_train / 255.0, x_test / 255.0
# Convert the data from NumPy arrays to tf.Tensor so the PySyft functionality
# (e.g. `.send`) is available on them.
x_train, y_train = tf.convert_to_tensor(x_train), tf.convert_to_tensor(y_train)
x_test, y_test = tf.convert_to_tensor(x_test), tf.convert_to_tensor(y_test)
# Send the training features AND labels to Alice (for demonstration purposes).
# `fit` below needs a pointer to both; previously only the features were sent,
# leaving `y_train_ptr` undefined.
x_train_ptr = x_train.send(alice)
y_train_ptr = y_train.send(alice)
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])
# Compile with optimizer, loss and metrics. The `accuracy` metric is what the
# training log below reports; the call was previously left unclosed.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model_ptr = model.send(alice)
print(model_ptr)
# ==> (Wrapper)>[ObjectPointer | me:random_id1 -> alice:random_id2]
model_ptr.fit(x_train_ptr, y_train_ptr, epochs=2)
# ==> Train on 60000 samples
# Epoch 1/2
# 60000/60000 [==============================] - 2s 36us/sample - loss: 0.3008 - accuracy: 0.9129
# Epoch 2/2
# 60000/60000 [==============================] - 2s 32us/sample - loss: 0.1449 - accuracy: 0.9569
@jvmncs
jvmncs / paillier.py
Last active August 18, 2020 22:01
Paillier Aggregation in TensorFlow Federated
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
from federated_aggregations import paillier

# Build an execution context from the executor factory provided by
# `federated_aggregations.paillier`, and install it as TFF's default context
# so subsequent federated computations run through it.
paillier_factory = paillier.local_paillier_executor_factory()
paillier_context = tff.framework.ExecutionContext(paillier_factory)
tff.framework.set_default_context(paillier_context)
# data from 5 clients: client i holds the int32 vector [i, i + 1].
# (`np` was previously used here without being imported.)
x = [np.array([i, i + 1], dtype=np.int32) for i in range(5)]
@jvmncs
jvmncs / equinox_inaxes.py
Last active May 13, 2022 15:47
eqx.filter_vmap failing to respect in_axes kwarg
import equinox as eqx
import jax
import numpy as np
def func(x):
    """Return ``x + x``.

    Toy function that this repro vectorizes with both ``jax.vmap`` and
    ``eqx.filter_vmap``. The body was previously unindented, which is an
    IndentationError.
    """
    return x + x
# Reference: vectorize `func` with plain JAX, mapping over axis 0 of its argument.
v_func_jax = jax.vmap(func, in_axes=0)
# Same mapping requested through Equinox's filter_vmap with the same `in_axes`
# keyword. Per this gist's title, filter_vmap fails to respect `in_axes` here
# (at the Equinox version the gist targeted).
v_func_eqx = eqx.filter_vmap(func, in_axes=0)
@jvmncs
jvmncs / jax-poetry.sh
Created January 25, 2024 19:44
Spinning up CUDA-enabled JAX in a Poetry project (January 2024)
#!/bin/bash
# Set up CUDA-enabled JAX in a Poetry project. Both of Google's JAX wheel
# indexes are registered as explicit-priority sources, so Poetry consults
# them only for packages that name them via --source.
jax_releases_base="https://storage.googleapis.com/jax-releases"

# jaxlib: register the CUDA wheel index, then pin jaxlib (with its CUDA 12 /
# cuDNN 8.9 extras) to that source.
poetry source add jaxlib "${jax_releases_base}/jax_cuda_releases.html" --priority explicit
poetry add jaxlib~=0.4.23 --extras="cuda12.cudnn89" --source=jaxlib

# jax: register the general release index.
poetry source add jax "${jax_releases_base}/jax_releases.html" --priority explicit
# Add jax package with specified version and extras
# NOTE(review): no `poetry add jax ...` command follows this comment in this copy.
@jvmncs
jvmncs / cuda-shell.nix
Last active January 30, 2025 20:36
uv-friendly devShell for CUDA-enabled PyTorch/JAX
# Rename this file to flake.nix and place it in your project folder next to an .envrc file containing "use flake".
{
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
};
outputs = { self, nixpkgs }:
let
system = "x86_64-linux"; # adjust if needed
pkgs = import nixpkgs {