@MLWhiz
Created March 9, 2019 15:00
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN_Text(nn.Module):
    def __init__(self):
        super(CNN_Text, self).__init__()
        filter_sizes = [1, 2, 3, 5]
        num_filters = 36
        # Embedding layer initialized from a pre-trained embedding matrix and frozen.
        # max_features, embed_size and embedding_matrix are expected to be defined
        # earlier in the surrounding notebook/script.
        self.embedding = nn.Embedding(max_features, embed_size)
        self.embedding.weight = nn.Parameter(torch.tensor(embedding_matrix, dtype=torch.float32))
        self.embedding.weight.requires_grad = False
        # One Conv2d per filter size; each kernel spans the full embedding width.
        self.convs1 = nn.ModuleList([nn.Conv2d(1, num_filters, (K, embed_size)) for K in filter_sizes])
        self.dropout = nn.Dropout(0.1)
        self.fc1 = nn.Linear(len(filter_sizes) * num_filters, 1)

    def forward(self, x):
        x = self.embedding(x)                                       # (batch, seq_len, embed_size)
        x = x.unsqueeze(1)                                          # (batch, 1, seq_len, embed_size)
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1]    # list of (batch, num_filters, seq_len - K + 1)
        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]      # list of (batch, num_filters) after max-over-time pooling
        x = torch.cat(x, 1)                                         # (batch, len(filter_sizes) * num_filters)
        x = self.dropout(x)
        logit = self.fc1(x)                                         # (batch, 1)
        return logit
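
A minimal smoke test of the module might look like the following. The values of max_features, embed_size, the random embedding_matrix, batch size, and sequence length are assumptions for illustration only; in the original usage these come from the surrounding notebook and a real pre-trained embedding matrix.

# Hypothetical sanity check; not part of the original gist.
import numpy as np

max_features = 10000          # vocabulary size (assumed)
embed_size = 300              # embedding dimension (assumed)
embedding_matrix = np.random.rand(max_features, embed_size)  # stand-in for pre-trained vectors

model = CNN_Text()
dummy_batch = torch.randint(0, max_features, (32, 70))  # 32 sequences of 70 token ids
logits = model(dummy_batch)
print(logits.shape)           # torch.Size([32, 1]) -- one raw logit per sequence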