Created
March 31, 2019 18:28
-
-
Save hereismari/c6c76ed8933b6cfed785d6c95b34e5fe to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy

import numpy as np
import pandas as pd
import syft as sy
import torch
import torch.nn.functional as F
from torch import nn, optim

# NOTE(review): `copy` and `pandas` are not used by any code visible in this file.

# Hook PyTorch so tensors/modules gain the PySyft methods used below
# (.send(), .get(), .move(), Module.copy()).
hook = sy.TorchHook(torch)

# Creating Virtual workers
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")
secure_worker = sy.VirtualWorker(hook, id="secure_worker")

# Register each worker with the other two. The later .move(secure_worker)
# calls presumably rely on the workers knowing about each other.
bob.add_workers([alice, secure_worker])
alice.add_workers([bob, secure_worker])
secure_worker.add_workers([alice, bob])
# Toy dataset: four 2-feature rows whose label equals the first feature.
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1.]])
Y = np.array([0, 0, 1, 1.])

# Features as float32; labels as int64 class indices, which CrossEntropyLoss expects.
data = torch.from_numpy(X).float()
target = torch.from_numpy(Y).long()
data_length, data_width = data.shape

# Sending Data: the first half of the rows goes to Bob, the rest to Alice.
# Integer division gives the split point directly (no float round-trip).
split_at = data_length // 2
bobs_data = data[:split_at].send(bob)
bobs_target = target[:split_at].send(bob)
alices_data = data[split_at:].send(alice)
alices_target = target[split_at:].send(alice)
# Definition of Model
class my_network(torch.nn.Module):
    """Fully connected classifier: in_features -> 200 -> 100 -> n_classes.

    Both hidden layers use ReLU.

    ``forward`` returns raw logits (unnormalised scores), not probabilities.
    ``torch.nn.CrossEntropyLoss`` applies log-softmax internally, so passing
    softmax output to it would normalise twice. Apply
    ``F.softmax(logits, dim=1)`` explicitly when probabilities are needed.
    """

    def __init__(self, in_features=None, n_classes=2):
        """Create the three linear layers.

        Args:
            in_features: width of each input row. When omitted, falls back to
                the module-level ``data_width`` taken from the training data.
            n_classes: number of output scores per row (default 2).
        """
        super().__init__()
        if in_features is None:
            in_features = data_width
        self.fc1 = nn.Linear(in_features, 200)
        self.fc2 = nn.Linear(200, 100)
        self.fc3 = nn.Linear(100, n_classes)

    def forward(self, input_):
        """Return logits of shape (batch, n_classes) for a (batch, in_features) input."""
        hidden = F.relu(self.fc1(input_))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
# Other model parameters
model = my_network()
loss = torch.nn.CrossEntropyLoss()

# Training
iterations = 100
worker_iters = 5


def _local_step(worker_model, worker_opt, worker_data, worker_target):
    """Run one SGD step of ``worker_model`` on its worker's data and return the loss.

    The loss is fetched back from the worker with ``.get()`` so it can be printed locally.
    """
    worker_opt.zero_grad()
    # Call the module rather than .forward() so nn.Module hooks still run.
    pred = worker_model(worker_data)
    step_loss = loss(pred, worker_target)
    step_loss.backward()
    worker_opt.step()
    return step_loss.get().data


for a_iter in range(iterations):
    # Each round starts both workers from a fresh copy of the current global model.
    bobs_model = model.copy().send(bob)
    alices_model = model.copy().send(alice)
    bobs_opt = optim.SGD(params=bobs_model.parameters(), lr=0.1)
    alices_opt = optim.SGD(params=alices_model.parameters(), lr=0.1)

    for wi in range(worker_iters):
        # Train Bob's model, then Alice's
        bobs_loss = _local_step(bobs_model, bobs_opt, bobs_data, bobs_target)
        alices_loss = _local_step(alices_model, alices_opt, alices_data, alices_target)

    # Move both local models to secure_worker and average them there. Only
    # the averaged tensors are pulled back into the global model via .get().
    alices_model.move(secure_worker)
    bobs_model.move(secure_worker)
    # parameters() yields tensors in registration order, so the three
    # identically-built models line up parameter-for-parameter.
    for global_param, alice_param, bob_param in zip(
        model.parameters(), alices_model.parameters(), bobs_model.parameters()
    ):
        global_param.data.set_(((alice_param.data + bob_param.data) / 2).get())

    print("Bob:" + str(bobs_loss) + " Alice:" + str(alices_loss))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment