PyTorch Regression
# Traditional Python regression packages like sklearn and statsmodels struggle once
# the number of examples grows beyond ~1M, or when the feature space is very high-dimensional.
# This implementation instead fits the model with a mini-batch gradient optimizer (Adam).
# We also provide a NullLogitModel that has only an intercept (used to compute
# McFadden's pseudo R-squared for the Logit model).
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
import torch.nn as nn
import numpy as np

# from tqdm import tqdm
from tqdm import tqdm_notebook as tqdm
class Logit(nn.Module):
    def __init__(self, x_dim, lr, cuda_id=-1):
        super(Logit, self).__init__()
        # a single linear layer followed by a sigmoid: logistic regression
        self.transform = nn.Sequential(
            nn.Linear(in_features=x_dim, out_features=1),
            nn.Sigmoid()
        )
        self.criterion = nn.BCELoss()
        # parameters must live on the GPU before the optimizer is constructed
        if cuda_id != -1:
            self.transform = self.transform.cuda(cuda_id)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
    def load_train_data(self, X, Y, batch_size=512):
        self.train_data = TensorDataset(X, Y)
        self.train_sampler = RandomSampler(self.train_data)  # can use a sequential sampler
        self.train_dataloader = DataLoader(self.train_data, sampler=self.train_sampler,
                                           batch_size=batch_size)
    def fit(self, epoch=5, silent=True, cuda_id=-1):
        self.train()
        for e in range(epoch):
            print("epoch {}".format(e))
            t = tqdm(iter(self.train_dataloader), leave=True, total=len(self.train_dataloader))
            for x, y in t:
                if cuda_id != -1:
                    x = x.cuda(cuda_id)
                    y = y.cuda(cuda_id)
                v_hat = torch.squeeze(self.transform(x))
                loss = self.criterion(v_hat, y)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                t.set_description('ML (loss=%g)' % loss.item())
            if not silent:
                print('epoch {}: {}'.format(e, loss.item()))
        print("model fitted to data")
class NullLogitModel(nn.Module):
    def __init__(self, x_dim, lr, cuda_id=-1):
        super(NullLogitModel, self).__init__()
        # intercept-only model; initialize at zero (a bare torch.FloatTensor(x_dim)
        # would contain uninitialized memory)
        self.null_intercept = nn.Parameter(torch.zeros(x_dim))
        self.criterion = nn.BCELoss()
        self.sig = nn.Sigmoid()
        # parameters must live on the GPU before the optimizer is constructed;
        # re-wrap in nn.Parameter because .cuda() returns a plain (non-parameter) tensor
        if cuda_id != -1:
            self.null_intercept = nn.Parameter(self.null_intercept.data.cuda(cuda_id))
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
    def load_train_data(self, X, Y, batch_size=512):
        self.train_data = TensorDataset(X, Y)
        self.train_sampler = RandomSampler(self.train_data)  # can use a sequential sampler
        self.train_dataloader = DataLoader(self.train_data, sampler=self.train_sampler,
                                           batch_size=batch_size)
    def fit(self, epoch=5, silent=True, cuda_id=-1):
        self.train()
        for e in range(epoch):
            print("epoch {}".format(e))
            t = tqdm(iter(self.train_dataloader), leave=False, total=len(self.train_dataloader))
            for x, y in t:
                if cuda_id != -1:
                    x = x.cuda(cuda_id)
                    y = y.cuda(cuda_id)
                # prediction from the learned per-feature intercept added to x, summed over features
                v_hat = torch.squeeze(self.sig((x + self.null_intercept).sum(1)))
                loss = self.criterion(v_hat, y)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                t.set_description('ML (loss=%g)' % loss.item())
            if not silent:
                print('epoch {}: {}'.format(e, loss.item()))
        print("final loss = {}".format(loss.item()))
        print("model fitted to data")
# fit the Logit model on GPU 6 with a batch size of 4096
model = Logit(Countries_np.shape[1], 1e-3)
model = model.cuda(6)
model.load_train_data(Countries_th, torch.tensor(Y).float(), 1024 * 4)
model.fit(epoch=2, silent=False, cuda_id=6)
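# --- Not part of the original gist: a hedged sketch of getting predicted probabilities
# on held-out data after fitting; `Countries_test_th` is a hypothetical float tensor
# with the same feature dimension as the training data.
model.eval()
with torch.no_grad():
    probs = torch.squeeze(model.transform(Countries_test_th.cuda(6)))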