Created
February 26, 2020 03:59
-
-
Save alexmlamb/f5c4241040fad812f39f1381910f4ca9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Standard library.
import random

# Numerics and autograd.
import numpy as np
import torch
from torch.autograd import Variable, grad

# Plotting. The non-interactive Agg backend must be selected *before*
# pyplot is imported so figures can be rendered straight to PNG files.
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pylab import rcParams

# Default size (in inches) for every figure this script saves.
rcParams['figure.figsize'] = (10, 10)
def sample_moons(n_samples=100, noise=0, device='cuda'):
    """Sample a two-moons binary classification dataset.

    Uses the same geometry as scikit-learn's ``make_moons``: class 0 lies on
    the upper unit half-circle (cos t, sin t); class 1 lies on the flipped,
    shifted half-circle (1 - cos t, 0.5 - sin t). Unlike ``make_moons``, the
    angles t are drawn uniformly at random from [0, pi) rather than on a grid.

    Args:
        n_samples: Requested number of points. Each class gets
            ``n_samples // 2`` points, so an odd value yields one point fewer.
        noise: Standard deviation of Gaussian noise added to every coordinate.
        device: Device the returned tensors are moved to. Defaults to
            ``'cuda'`` (the original behaviour).

    Returns:
        ``(X, y)``. ``X`` is a float tensor of shape ``(2 * (n_samples // 2), 2)``.
        ``y`` has shape ``(2 * (n_samples // 2), 1)``; its first half is 0.0
        (outer moon) and its second half is 1.0 (inner moon).
    """
    n_per_class = n_samples // 2

    outer_angle = torch.rand(n_per_class, 1) * np.pi
    outer = torch.cat((torch.cos(outer_angle), torch.sin(outer_angle)), 1)

    inner_angle = torch.rand(n_per_class, 1) * np.pi
    inner = torch.cat((1 - torch.cos(inner_angle),
                       1 - torch.sin(inner_angle) - .5), 1)

    X = torch.cat((outer, inner), 0)
    y = torch.cat((torch.zeros(n_per_class), torch.ones(n_per_class)), 0)
    # The noise is drawn even when noise == 0. This keeps the number of torch
    # RNG draws (and hence every later random number) independent of the
    # noise level.
    X += torch.randn(X.size()) * noise
    print(X.shape, y.shape)
    return X.detach().to(device), y.view(-1, 1).detach().to(device)
def sample_linear(n_samples=100, noise=0, device='cuda'):
    """Sample two wide, vertically separated uniform boxes (linearly separable).

    Each class is drawn uniformly from a 55.0 x 5.1 rectangle centred on
    x = 0. Class 0 is shifted down by 3.0 and class 1 up by 3.0. With
    ``noise=0`` the classes are therefore separated by the band -0.45 < y < 0.45.

    Args:
        n_samples: Requested number of points; each class gets
            ``n_samples // 2``.
        noise: Standard deviation of Gaussian noise added to every coordinate.
        device: Device the returned tensors are moved to. Defaults to
            ``'cuda'`` (the original behaviour).

    Returns:
        ``(X, y)``. ``X`` is a float tensor of shape ``(2 * (n_samples // 2), 2)``.
        ``y`` is a 1-D float tensor of zeros (first half) then ones.
        NOTE(review): ``y`` is 1-D here, while ``sample_moons`` returns
        ``(N, 1)``. ``train_net``'s BCELoss target must match the ``(N, 1)``
        network output, so reshape with ``y.view(-1, 1)`` before training on it.
    """
    n_per_class = n_samples // 2
    # Uniform in [-0.5, 0.5) before per-axis scaling.
    first = torch.rand(n_per_class, 2) * 1.0 - 0.5
    second = torch.rand(n_per_class, 2) * 1.0 - 0.5
    for box in (first, second):
        box[:, 0] *= 55.0
        box[:, 1] *= 5.1
    first[:, 1] -= 3.0
    second[:, 1] += 3.0

    X = torch.cat((first, second), 0)
    y = torch.cat((torch.zeros(n_per_class), torch.ones(n_per_class)), 0)
    # Drawn unconditionally so the RNG stream does not depend on `noise`.
    X += torch.randn(X.size()) * noise
    print(X.shape, y.shape)
    return X.to(device), y.to(device)
def sample_twospirals(n_samples=100, noise=0, device='cuda', seed=42):
    """Sample a two-spirals binary classification dataset.

    For each of ``n_samples // 2`` points, an angle ``t = sqrt(U) * 780 deg``
    (in radians) is drawn, with U ~ Uniform[0, 1). Class 0 is the point
    ``(-cos(t) * t, sin(t) * t)`` plus a uniform ``[0, noise)`` offset per
    coordinate. Class 1 is the point reflection through the origin of class 0.

    A private ``np.random.RandomState(seed)`` is used, so the output is
    identical to seeding the global NumPy RNG with ``seed``. Unlike the
    original global ``np.random.seed(42)``, this leaves the global NumPy RNG
    state untouched for other code (e.g. ``mixup``, ``train_net``).

    Args:
        n_samples: Requested number of points; each class gets
            ``n_samples // 2``.
        noise: Scale of the uniform, non-negative per-coordinate offset.
        device: Device the returned tensors are moved to. Defaults to
            ``'cuda'`` (the original behaviour).
        seed: Seed for the private RNG. Defaults to 42 (the original
            behaviour); ``None`` draws fresh entropy.

    Returns:
        ``(X, Y)``: float32 tensors. ``X`` has shape
        ``(2 * (n_samples // 2), 2)``; ``Y`` is 1-D with zeros then ones.
    """
    rng = np.random.RandomState(seed)
    n_points = n_samples // 2
    t = np.sqrt(rng.rand(n_points, 1)) * 780 * (2 * np.pi) / 360
    d1x = -np.cos(t) * t + rng.rand(n_points, 1) * noise
    d1y = np.sin(t) * t + rng.rand(n_points, 1) * noise

    X = np.vstack((np.hstack((d1x, d1y)), np.hstack((-d1x, -d1y))))
    Y = np.hstack((np.zeros(n_points), np.ones(n_points)))
    print(X.shape, Y.shape)
    return (torch.from_numpy(X.astype('float32')).to(device),
            torch.from_numpy(Y.astype('float32')).to(device))
def plot_moons(net, x_labeled, y_labeled, x_unlabeled, meshres,
               limits=(-5.0, 5.0, -5.0, 5.0), filename='mymoon.png'):
    """Plot a classifier's output over the input plane plus the labeled points.

    The labeled points are drawn blue (label 0) or red (otherwise). ``net``
    is evaluated on a regular grid and shown as a translucent heat map. The
    figure is saved to ``filename`` and then cleared.

    Args:
        net: Callable mapping an ``(M, 2)`` float tensor to ``M`` scores,
            e.g. ``lambda inp: net2(net1(inp))``.
        x_labeled: ``(N, 2)`` labeled inputs. The grid is evaluated on this
            tensor's device.
        y_labeled: Labels for ``x_labeled``, shape ``(N,)`` or ``(N, 1)``.
        x_unlabeled: ``(M, 2)`` inputs. Used only when ``limits`` is None,
            to derive plot bounds (data range padded by 0.25).
        meshres: Grid spacing.
        limits: ``(x_min, x_max, y_min, y_max)``. Defaults to the original
            fixed +/-5 box.
        filename: Output image path.
    """
    labels = y_labeled.detach().reshape(-1).cpu().tolist()
    colors = ['blue' if label == 0 else 'red' for label in labels]
    points = x_labeled.detach().cpu().numpy()
    plt.scatter(points[:, 0], points[:, 1], c=colors, s=100)

    if limits is None:
        x_min = x_unlabeled[:, 0].min().item() - 0.25
        x_max = x_unlabeled[:, 0].max().item() + 0.25
        y_min = x_unlabeled[:, 1].min().item() - 0.25
        y_max = x_unlabeled[:, 1].max().item() + 0.25
    else:
        x_min, x_max, y_min, y_max = limits

    # Each axis of the grid uses its own bounds (the y axis must use y_min/y_max).
    xx, yy = np.meshgrid(np.arange(x_min, x_max, meshres),
                         np.arange(y_min, y_max, meshres))
    grid = torch.Tensor(np.c_[xx.ravel(), yy.ravel()]).to(x_labeled.device)
    with torch.no_grad():
        Z = net(grid).detach().cpu().numpy().reshape(xx.shape)

    plt.imshow(Z, extent=(x_min, x_max, y_min, y_max), cmap='jet',
               interpolation='bilinear', origin='lower', alpha=0.5)
    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    plt.savefig(filename)
    plt.clf()
def mixup(x, y, a=1, extrapolate=False):
    """Mix every example with a randomly permuted partner from the same batch.

    Interpolation (default): draw one coefficient ``lam ~ Beta(a, a)`` for the
    whole batch and a random row permutation ``p``. Return
    ``lam * x + (1 - lam) * x[p]`` and ``lam * y + (1 - lam) * y[p]``.

    Extrapolation (``extrapolate=True``, previously unreachable behind a
    hard-coded flag): draw ``u ~ Beta(a + 0.01, a + 0.01)`` and set
    ``x_mix = x + u * (x[p] - x)``. Targets are weighted by inverse L1
    distance, summed over the whole batch, to ``x`` and ``x[p]``. The mixed
    targets are detached from the autograd graph.

    Args:
        x: Batch of inputs; rows are examples.
        y: Targets aligned with the rows of ``x``.
        a: Beta distribution concentration parameter.
        extrapolate: Select the extrapolation variant.

    Returns:
        ``(x_mix, y_mix)`` with the same shapes as ``x`` and ``y``.
    """
    if extrapolate:
        u = float(np.random.beta(a + 0.01, a + 0.01))
        # Permutation drawn on the CPU generator, then moved to x's device.
        p = torch.randperm(x.size(0)).to(x.device)
        x_mix = x + u * (x[p] - x)
        eps = 0.0001
        d1 = torch.abs(x_mix - x).sum()
        d2 = torch.abs(x_mix - x[p]).sum()
        mr = (1 / (d1 + eps)) / (1 / (d1 + eps) + 1 / (d2 + eps))
        y_mix = (mr * y + (1 - mr) * y[p]).detach()
    else:
        lam = float(np.random.beta(a, a))
        p = torch.randperm(x.size(0)).to(x.device)
        x_mix = lam * x + (1 - lam) * x[p]
        y_mix = lam * y + (1 - lam) * y[p]
    return x_mix, y_mix
def _mlp_layers(widths):
    """Return Linear layers joining consecutive ``widths``, with a LeakyReLU after each but the last."""
    layers = []
    last = len(widths) - 2
    for i, (d_in, d_out) in enumerate(zip(widths[:-1], widths[1:])):
        layers.append(torch.nn.Linear(d_in, d_out))
        if i < last:
            layers.append(torch.nn.LeakyReLU())
    return layers


def _labeled_loss(net1, net2, bce, x, y, mix):
    """Return the labeled BCE loss for one step under mixing mode ``mix``."""
    if mix == 'none':
        return bce(net2(net1(x)), y)
    # Mixing coefficient ~ N(0.5, 0.2). It is not clipped, so it can fall
    # outside [0, 1] (mild extrapolation). Drawn before the permutation, as
    # in the original, so the RNG streams are unchanged.
    lamb = np.random.normal(0.5, 0.2)
    perm = torch.randperm(x.shape[0])
    if mix == 'x':
        x_mix = x * lamb + x[perm] * (1 - lamb)
        out = net2(net1(x_mix))
    else:  # mix == 'h': mix in the 2-D hidden space produced by net1
        h = net1(x)
        h_mix = h * lamb + h[perm] * (1 - lamb)
        out = net2(h_mix)
    # One forward pass serves both target terms.
    return bce(out, y) * lamb + bce(out, y[perm]) * (1 - lamb)


def train_net(x_labeled,
              y_labeled,
              x_unlabeled,
              n_hiddens=512,
              n_iterations=3200,
              lamba=1,
              mix="h"):
    """Train an encoder/classifier pair on labeled data with mixup-style mixing.

    ``net1`` maps inputs to a 2-D hidden space through four hidden
    LeakyReLU layers. ``net2`` maps that 2-D space to a sigmoid probability
    through three hidden LeakyReLU layers. Both are trained jointly with
    Adam (lr=1e-4, weight_decay=1e-4) on full-batch BCE.

    Args:
        x_labeled: ``(N, D)`` float inputs. The networks are placed on this
            tensor's device.
        y_labeled: ``(N, 1)`` float targets in [0, 1] (BCELoss needs them
            shaped like the ``(N, 1)`` output).
        x_unlabeled: Unused; kept for interface compatibility. The
            unlabeled-consistency term is disabled.
        n_hiddens: Width of every hidden layer.
        n_iterations: Number of optimisation steps.
        lamba: Unused; was the weight of the disabled unlabeled term.
        mix: ``'h'`` (mix hidden representations, default), ``'x'`` (mix
            inputs) or ``'none'`` (plain BCE).

    Returns:
        ``(net1, net2)``; ``net2(net1(x))`` gives class-1 probabilities.

    Raises:
        ValueError: if ``mix`` is not one of the supported modes.
    """
    if mix not in ('none', 'x', 'h'):
        raise ValueError('mixing option not found')
    device = x_labeled.device

    # Module order and indices match the original explicit Sequential
    # definitions, so parameter init draws and state_dict keys are unchanged.
    net1 = torch.nn.Sequential(
        *_mlp_layers([x_labeled.size(1)] + [n_hiddens] * 4 + [2]))
    net2 = torch.nn.Sequential(
        *_mlp_layers([2] + [n_hiddens] * 3 + [1]), torch.nn.Sigmoid())
    net1.to(device)
    net2.to(device)

    opt = torch.optim.Adam(list(net1.parameters()) + list(net2.parameters()),
                           lr=0.0001, weight_decay=1e-4)
    bce = torch.nn.BCELoss()

    for iteration in range(n_iterations):
        opt.zero_grad()
        error_labeled = _labeled_loss(net1, net2, bce, x_labeled, y_labeled, mix)
        error_labeled.backward()
        opt.step()
        if iteration % 100 == 0:
            print(iteration, error_labeled.item())
    return net1, net2
def plot_h(net1, net2, x_labeled, bound_stat, meshres,
           n_random=300, random_range=9.0, filename='hmoon.png'):
    """Plot labeled points and random probes in net1's hidden space over net2's output.

    Three things are drawn:
    - the labeled points mapped through ``net1`` (first half of the rows
      blue, second half red);
    - ``n_random`` uniform probes from ``[-random_range, random_range)^D``
      mapped through ``net1`` (faint black);
    - ``net2`` evaluated on a square grid covering the labeled hidden points
      padded by ``bound_stat``.

    The figure is saved to ``filename`` and cleared.

    NOTE(review): this assumes ``net1`` outputs exactly 2 features (as
    ``train_net`` builds it). The blue/red colouring assumes class-0 rows
    come first and class-1 rows second, as ``sample_moons`` produces.

    Args:
        net1: Encoder mapping ``(N, D)`` inputs to ``(N, 2)`` hidden points.
        net2: Head mapping ``(M, 2)`` hidden points to ``M`` scores.
        x_labeled: ``(N, D)`` labeled inputs. Computation runs on this
            tensor's device.
        bound_stat: Padding added around the hidden-point range.
        meshres: Grid spacing.
        n_random: Number of random probe inputs.
        random_range: Half-width of the probe sampling box.
        filename: Output image path.
    """
    device = x_labeled.device
    with torch.no_grad():
        h = net1(x_labeled.detach())
        # CPU generator, then moved, so the random stream matches .cuda() use.
        probes = (2.0 * random_range
                  * torch.rand((n_random, x_labeled.size(1))).to(device)
                  - random_range)
        hr = net1(probes)

    half = h.size(0) // 2
    h_np = h.cpu().numpy()
    hr_np = hr.cpu().numpy()
    plt.scatter(h_np[:half, 0], h_np[:half, 1], color='blue')
    plt.scatter(h_np[half:, 0], h_np[half:, 1], color='red')
    plt.scatter(hr_np[:, 0], hr_np[:, 1], color='black', alpha=0.08, linewidth=0.5)

    # The grid must span BOTH hidden coordinates (columns 0 and 1) of the
    # labeled points, not one column plus the first row.
    p_min = min(h[:, 0].min().item(), h[:, 1].min().item())
    p_max = max(h[:, 0].max().item(), h[:, 1].max().item())
    axis = np.arange(p_min - bound_stat, p_max + bound_stat, meshres)
    xx, yy = np.meshgrid(axis, axis)

    grid = torch.Tensor(np.c_[xx.ravel(), yy.ravel()]).to(device)
    with torch.no_grad():
        Z = net2(grid).cpu().numpy().reshape(xx.shape)
    plt.imshow(Z, extent=(xx.min(), xx.max(), yy.min(), yy.max()), cmap='jet',
               interpolation='bilinear', origin='lower', alpha=0.5)
    plt.savefig(filename)
    plt.clf()
if __name__ == "__main__":
    lamba = 1
    # Seed both RNGs the script draws from explicitly:
    # - torch: data sampling, weight init, permutations;
    # - NumPy: the mixing coefficients in train_net.
    # NumPy reproducibility must not rely on a data sampler reseeding the
    # global RNG as a side effect.
    torch.manual_seed(1000)
    np.random.seed(1000)
    # 128 labeled two-moons points (64 per class).
    x_labeled, y_labeled = sample_moons(128, 0.01)
    # Only the spiral inputs are passed on (to train_net and plot_moons);
    # the spiral labels are not used.
    x_unlabeled, _ = sample_twospirals(1000, noise=0.1)
    net1, net2 = train_net(x_labeled, y_labeled, x_unlabeled, lamba=lamba)
    plot_moons(lambda inp: net2(net1(inp)), x_labeled, y_labeled, x_unlabeled, meshres=0.1)
    plot_h(net1, net2, x_labeled, bound_stat=0.2, meshres=0.2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment