Skip to content

Instantly share code, notes, and snippets.

@jeremy-rutman
Last active March 9, 2020 11:17
Show Gist options
  • Select an option

  • Save jeremy-rutman/997c8fadd53dfde38f11a8e9d0ef57e3 to your computer and use it in GitHub Desktop.

Select an option

Save jeremy-rutman/997c8fadd53dfde38f11a8e9d0ef57e3 to your computer and use it in GitHub Desktop.
# Attempt at 'batch-known' learning, where positive/negative labels are known per batch rather than per example.
# In this attempt an example is positive if its input vector -- itself six random binary digits -- contains five or more ones.
# Not only is this underlying truth (y = (sum(x) >= 5)) hidden from the loss function; the loss function only knows, per batch,
# whether there was at least one true example in the batch (a 'positive batch') or none (a 'negative batch').
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import torch
import numpy as np
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle
from sklearn.metrics import roc_curve, auc
# --- Synthetic data ---------------------------------------------------------
# Each row holds `feature_dimensions` independent random 0/1 values. A row is a
# positive example when at least `positive_threshold` of its entries are 1.
n_rows = 10000
feature_dimensions = 6
n_test = 1000  # rows held out at the end of the array for evaluation
positive_threshold = 5
features = np.random.randint(0, 2, [n_rows, feature_dimensions])
# One 0/1 label per row. The per-row sum is already 1-D, so no transpose is needed.
results = (np.sum(features, axis=1) >= positive_threshold).astype(int)
print('shapes x {} y {}'.format(features.shape, results.shape))
n_train = len(features) - n_test
features_train = features[:n_train]
features_test = features[n_train:]
results_train = results[:n_train]
results_test = results[n_train:]
print('train', features_train.shape, 'test', features_test.shape, 'restrain', results_train.shape, 'restest',
      results_test.shape)
print('feat', features_train[0:10], 'res', results_train[0:10])
# With the defaults above (6 fair bits, threshold 5), P(row positive) = 7/64 ~= 0.11.
# A batch of 8 therefore has no positive with probability (57/64)**8 ~= 0.40,
# so roughly 40% of batches are negative and 60% positive.
batchsize = 8
# --- Hand-made network ---------------------------------------------------------
# N: batch size, D_in: input dimension, H: hidden width, D_out: output dimension
# (one raw, unsquashed score per example).
N = batchsize
D_in = feature_dimensions
H = 10
D_out = 1
# Layers are created in the same order as a single Sequential(...) literal,
# so random parameter initialisation is consumed identically.
hidden_layer = torch.nn.Linear(D_in, H)
output_layer = torch.nn.Linear(H, D_out)
model = torch.nn.Sequential(hidden_layer, torch.nn.ReLU(), output_layer)

learning_rate = 1e-4
# Adam with AMSGrad enabled; betas/eps/weight_decay restate Adam's defaults.
# (Other options tried: plain Adam, or torch.optim.SGD at the same learning rate.)
adam_settings = dict(lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=True)
optimizer = torch.optim.Adam(model.parameters(), **adam_settings)
# A per-example torch.nn.MSELoss(reduction='sum') was considered; the batch-level
# loss_fn is used instead because labels are only known per batch.
def loss_fn(y_pred, y):
    """Batch-level L1 loss for 'batch-known' training.

    The loss only knows, for the whole batch, whether the batch contained at
    least one positive example. The batch's prediction is taken to be the
    largest per-example output, so the loss is ``|y - max(y_pred)|``.

    Args:
        y_pred: tensor of per-example model outputs (e.g. shape (batch, 1)).
        y: batch-level label -- 1 if the batch has at least one positive
            example, else 0 -- as a 0-d tensor or a Python number.

    Returns:
        0-d tensor holding the loss, differentiable with respect to ``y_pred``.
    """
    # torch.max reduces over all elements in one op and returns a 0-d tensor,
    # instead of Python's builtin max iterating the batch row by row.
    pred_result = torch.max(y_pred)
    return torch.abs(y - pred_result)
def epoch():
    """Run one training pass over the training split using batch-level labels only.

    Walks ``features_train``/``results_train`` in consecutive, unshuffled chunks
    of ``batchsize`` rows. Each chunk's per-example labels are collapsed into a
    single batch label (1.0 if any example in the chunk is positive, else 0.0),
    so the loss never learns which example was positive.

    Relies on the module-level ``model``, ``optimizer``, ``loss_fn``,
    ``features_train``, ``results_train``, ``n_train`` and ``batchsize``.

    Returns:
        ``(n_positives, n_batches)``: the number of positive batches and the
        total number of batches processed (also printed).
    """
    n_positives = 0
    n_batches = 0
    for t in range(0, n_train, batchsize):
        # Forward pass.
        x = torch.from_numpy(features_train[t:t + batchsize]).float()
        # Batch label: 1.0 if at least one example in the batch is positive, else 0.0.
        y = torch.from_numpy(results_train[t:t + batchsize]).float().max()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        # Accumulate a plain Python int so the summary prints a number, not a tensor.
        n_positives += int(y.item())
        n_batches += 1
        if t == 0:
            print(t, loss.item(), 'sum', y.item())
        # Reset gradients, backpropagate, and update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('npos ', n_positives, 'n_batches', n_batches)
    return n_positives, n_batches
# --- Training: repeat the batch-level epoch a fixed number of times ---
n_epochs = 500
for _ in range(n_epochs):
    epoch()
# --- Evaluation on the held-out split (per-example labels are used only here) ---
x = torch.from_numpy(features_test).float()
y = torch.from_numpy(results_test).float()
# Inference only: no_grad avoids recording an autograd graph for the whole test set.
with torch.no_grad():
    y_pred = model(x)
# Row positions of the truly positive test examples (results_test is 1-D).
indices = np.flatnonzero(results_test == 1)
print('x', x[0:5])
print('y', y[0:50])
print('ypred', y_pred[0:50])
print('predictions for positive inputs at indices ', indices)
print(y_pred[torch.from_numpy(indices)])
y_np = results_test
y_pred_np = y_pred.numpy()
# The raw network outputs serve directly as scores. The ROC curve and AUC depend
# only on how the scores rank the examples, so no sigmoid is needed.
fpr, tpr, _ = roc_curve(y_np.ravel(), y_pred_np.ravel())
roc_auc = auc(fpr, tpr)
# --- ROC curve of the held-out predictions, with the chance diagonal for reference ---
lw = 2
fig, ax = plt.subplots()
ax.plot(fpr, tpr, color='darkorange', lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
ax.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Receiver operating characteristic example')
ax.legend(loc="lower right")
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment