import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class MLP(nn.Module):
    def __init__(self, input_size, feature_categories):
        super(MLP, self).__init__()
        self.feature_categories = feature_categories
        # Total output width is the sum of the category counts
        out_size = int(np.sum(feature_categories))
        # Two fully connected layers
        self.fc1 = nn.Linear(input_size, 64)
        self.fc2 = nn.Linear(64, out_size)
    def forward(self, x):
        # Pass through both layers; fc2 produces the concatenated logits
        z = self.fc2(F.relu(self.fc1(x)))
        start = 0
        outputs = []
        # Take a log-softmax over the slice of logits belonging to each
        # categorical variable, advancing the offset as we go
        for csize in self.feature_categories:
            category_logits = F.log_softmax(z[:, start:start+csize], dim=1)
            outputs.append(category_logits)
            start += csize
        return outputs
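
# Note (added): each element of the list returned by forward has shape
# (batch_size, csize) and holds log-probabilities, which is what
# F.nll_loss expects. A quick shape check, assuming the setup from the
# __main__ block below:
#
#   model = MLP(784, [5, 4, 7, 9])
#   out = model(torch.randn(2, 784))
#   assert [o.shape[1] for o in out] == [5, 4, 7, 9]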
def random_labels(batch_size, feature_categories):
    # Create a random batch of index labels for each category
    return [torch.randint(0, fc, (batch_size,), dtype=torch.long) for fc in feature_categories]
if __name__ == "__main__":
    input_size = 784
    batch_size = 128
    # Let's say we have 4 variables with respective numbers of
    # categories 5, 4, 7, 9.
    feature_categories = [5, 4, 7, 9]
    # Random data
    data = torch.randn(batch_size, input_size)
    # Random labels
    label_idxs = random_labels(batch_size, feature_categories)
    # Init the model
    model = MLP(input_size, feature_categories)
    # Forward pass
    out = model(data)
    # Compute a separate NLL loss per categorical head, then sum them
    # into a single scalar objective
    losses = [F.nll_loss(feature, label) for feature, label in zip(out, label_idxs)]
    net_loss = sum(losses)
    # Backprop
    net_loss.backward()
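
    # A minimal sketch of completing one optimization step (not part of
    # the original gist; the choice of SGD and the learning rate are
    # assumptions):
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    optimizer.step()       # apply the gradients from net_loss.backward()
    optimizer.zero_grad()  # clear gradients before the next iteration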