# Federated-learning example using PySyft.
# Originally published as a GitHub Gist by @hereismari, created March 31, 2019 18:28
# (gist id c6c76ed8933b6cfed785d6c95b34e5fe).
import copy

import numpy as np
import pandas as pd
import syft as sy
import torch
import torch.nn.functional as F
from torch import nn
from torch import optim

# Hook PyTorch so tensors and modules support the PySyft remote-execution
# methods used below (.send, .get, .move, .copy). Done after all imports;
# the hook patches the shared torch module/class objects, so the names
# imported above see the hooked behaviour as well.
hook = sy.TorchHook(torch)

# Creating virtual workers: bob and alice each receive a data shard,
# secure_worker is only used to average the two trained models.
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")
secure_worker = sy.VirtualWorker(hook, id="secure_worker")

# Register every worker with the other two so remote objects can later be
# moved between them (the training loop calls .move(secure_worker)).
bob.add_workers([alice, secure_worker])
alice.add_workers([bob, secure_worker])
secure_worker.add_workers([alice, bob])
# Toy dataset: 4 rows of 2 binary features; the label equals the first
# feature of each row.
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1.]])
Y = np.array([0, 0, 1, 1.])

# Features as float32 for nn.Linear; labels as int64 class indices, which is
# what CrossEntropyLoss expects for its target.
data = torch.from_numpy(X).float()
target = torch.from_numpy(Y).long()
data_length, data_width = data.shape

# Sending data: bob gets the first half of the rows, alice the remainder.
# After .send() only pointers to the shards remain in this process.
split_at = data_length // 2
bobs_data = data[:split_at].send(bob)
bobs_target = target[:split_at].send(bob)
alices_data = data[split_at:].send(alice)
alices_target = target[split_at:].send(alice)
# Model definition
class my_network(torch.nn.Module):
    """Fully connected classifier: in_features -> hidden1 -> hidden2 -> n_classes.

    ``forward`` returns raw, unnormalised logits. The loss used for training
    is ``nn.CrossEntropyLoss``, which applies ``log_softmax`` internally and
    expects logits; applying softmax inside the model as well ("double
    softmax") flattens the gradients. Apply ``F.softmax(out, dim=1)`` to the
    output when probabilities are needed.

    Args:
        in_features: Width of each input row. Defaults to the module-level
            ``data_width`` (the feature count of the training data).
        hidden1: Units in the first hidden layer.
        hidden2: Units in the second hidden layer.
        n_classes: Number of output logits per row.
    """

    def __init__(self, in_features=None, hidden1=200, hidden2=100, n_classes=2):
        super().__init__()
        if in_features is None:
            # Backward-compatible default: size the input layer from the data.
            in_features = data_width
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, n_classes)

    def forward(self, input_):
        """Map an (N, in_features) batch to (N, n_classes) logits."""
        hidden = F.relu(self.fc1(input_))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
# Global model (overwritten with the averaged weights every round) and the
# classification loss shared by both workers.
loss = nn.CrossEntropyLoss()
model = my_network()
# Training: each round, bob and alice train private copies of the global
# model on their own shard; the copies are then averaged on secure_worker and
# the result is written back into the global model.
iterations = 100  # number of averaging rounds
worker_iters = 5  # local SGD steps per worker per round
for a_iter in range(iterations):
    # Send a fresh copy of the current global model to each data owner.
    bobs_model = model.copy().send(bob)
    alices_model = model.copy().send(alice)
    # New optimisers every round: they must track the parameters of the new copies.
    bobs_opt = optim.SGD(params=bobs_model.parameters(), lr=0.1)
    alices_opt = optim.SGD(params=alices_model.parameters(), lr=0.1)

    for wi in range(worker_iters):
        # Train Bob's model on Bob's shard.
        bobs_opt.zero_grad()
        bobs_pred = bobs_model(bobs_data)
        bobs_loss = loss(bobs_pred, bobs_target)
        bobs_loss.backward()
        bobs_opt.step()
        bobs_loss = bobs_loss.get().data

        # Train Alice's model on Alice's shard.
        alices_opt.zero_grad()
        alices_pred = alices_model(alices_data)
        alices_loss = loss(alices_pred, alices_target)
        alices_loss.backward()
        alices_opt.step()
        alices_loss = alices_loss.get().data

    # Move both trained copies to secure_worker so the average is computed
    # there; only the averaged values are fetched back with .get().
    alices_model.move(secure_worker)
    bobs_model.move(secure_worker)

    # Parameter-wise mean of the two copies, written into the global model.
    # The three models share one architecture, so parameters() yields
    # matching tensors in the same order for each.
    for global_param, alice_param, bob_param in zip(
        model.parameters(), alices_model.parameters(), bobs_model.parameters()
    ):
        averaged = ((alice_param.data + bob_param.data) / 2).get()
        global_param.data.set_(averaged)

    # Loss from each worker's last local step this round (!s keeps the
    # str(tensor) formatting rather than Tensor.__format__).
    print(f"Bob:{bobs_loss!s} Alice:{alices_loss!s}")
# Comments on the original gist are hosted on GitHub (sign-in required).