This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# rename this file to flake.nix, put it next to a "use flake" .envrc file in your project folder | |
{ | |
inputs = { | |
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; | |
}; | |
outputs = { self, nixpkgs }: | |
let | |
system = "x86_64-linux"; # adjust if needed | |
pkgs = import nixpkgs { |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash | |
# Add jaxlib source with priority explicit | |
poetry source add jaxlib https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --priority explicit | |
# Add jaxlib package with specified version and extras | |
poetry add jaxlib~=0.4.23 --extras="cuda12.cudnn89" --source=jaxlib | |
# Add jax source with priority explicit | |
poetry source add jax https://storage.googleapis.com/jax-releases/jax_releases.html --priority explicit | |
# Add jax package with specified version and extras |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import equinox as eqx | |
import jax | |
import numpy as np | |
def func(x): | |
return x + x | |
v_func_jax = jax.vmap(func, in_axes=0) | |
v_func_eqx = eqx.filter_vmap(func, in_axes=0) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import tensorflow_federated as tff | |
from federated_aggregations import paillier | |
paillier_factory = paillier.local_paillier_executor_factory() | |
paillier_context = tff.framework.ExecutionContext(paillier_factory) | |
tff.framework.set_default_context(paillier_context) | |
# data from 5 clients | |
x = [np.array([i, i + 1], dtype=np.int32) for i in range(5)] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
model_ptr.fit(x_train_ptr, y_train_ptr, epochs=2) | |
# ==> Train on 60000 samples | |
# Epoch 1/2 | |
# 60000/60000 [==============================] - 2s 36us/sample - loss: 0.3008 - accuracy: 0.9129 | |
# Epoch 2/2 | |
# 60000/60000 [==============================] - 2s 32us/sample - loss: 0.1449 - accuracy: 0.9569 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
model_ptr = model.send(alice) | |
print(model_ptr) | |
# ==> (Wrapper)>[ObjectPointer | me:random_id1 -> alice:random_id2] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
model = tf.keras.models.Sequential([ | |
tf.keras.layers.Flatten(input_shape=(28, 28)), | |
tf.keras.layers.Dense(128, activation='relu'), | |
tf.keras.layers.Dropout(0.2), | |
tf.keras.layers.Dense(10, activation='softmax') | |
]) | |
# Compile with optimizer, loss and metrics | |
model.compile(optimizer='adam', | |
loss='sparse_categorical_crossentropy', |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
mnist = tf.keras.datasets.mnist | |
(x_train, y_train), (x_test, y_test) = mnist.load_data() | |
x_train, x_test = x_train / 255.0, x_test / 255.0 | |
# Converting the data from numpy to tf.Tensor in order to have PySyft functionalities. | |
x_train, y_train = tf.convert_to_tensor(x_train), tf.convert_to_tensor(y_train) | |
x_test, y_test = tf.convert_to_tensor(x_test), tf.convert_to_tensor(y_test) | |
# Send data to Alice (for demonstration purposes) | |
x_train_ptr = x_train.send(alice) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
x = tf.expand_dims(id[0], 0) | |
# Initialize the weight | |
w_init = tf.initializers.glorot_normal() | |
w = tf.Variable(w_init(shape=(2, 1), dtype=tf.float32)).send(alice) | |
z = tf.matmul(x, w) | |
# Manual differentiation & update | |
dzdx = tf.transpose(x) | |
w.assign_sub(dzdx) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
y_ptr = x_ptr + x_ptr | |
y = tf.reshape(y_ptr, shape=[2, 2]) | |
id = tf.constant([[1., 0.], [0., 1.]]).send(alice) | |
z = tf.matmul(y, id).get() | |
print(z) | |
# ==> tf.Tensor([[2. 4.] | |
# [6. 8.]], shape=(2, 2), dtype=float32) |
NewerOlder