from gensim.test.utils import common_texts
from gensim.corpora.dictionary import Dictionary
from gensim.models import LdaModel

# Create a corpus from a list of texts
common_dictionary = Dictionary(common_texts)
common_corpus = [common_dictionary.doc2bow(text) for text in common_texts]

# Train the model on the corpus.
lda = LdaModel(common_corpus, num_topics=10)
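Once trained, the model can score unseen documents through the same dictionary. A minimal sketch using gensim's standard LdaModel indexing API (common_texts is gensim's tiny built-in toy corpus, so the output is illustrative only):

# Infer the topic mixture of an unseen document (example words from the toy corpus)
other_bow = common_dictionary.doc2bow(['human', 'computer', 'system'])
print(lda[other_bow])  # list of (topic_id, probability) pairs
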
# Efficiently compute the aggregation of pairwise feature interactions
# using the FM identity: 0.5 * ((sum of features)^2 - sum of squared features).
emb_list = [emb_item, pooled_interaction, emb_price_rank, emb_city, emb_last_item, emb_impression_index, emb_star]
emb_concat = torch.cat(emb_list, dim=1)
square_of_sum = torch.pow(torch.sum(emb_concat, dim=1), 2).unsqueeze(1)
sum_of_square = torch.sum(torch.pow(emb_concat, 2), dim=1).unsqueeze(1)
second_order = 0.5 * (square_of_sum - sum_of_square)
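Because the embeddings above are concatenated along dim=1, the sum collapses every field and every embedding dimension into a single scalar per example before squaring. A sketch of the same identity on stacked embeddings, which keeps per-dimension interactions as in the standard FM formulation (the shapes here are illustrative assumptions, not from the original snippet):

import torch

batch, fields, dim = 4, 7, 16
embs = torch.randn(batch, fields, dim)       # (batch, fields, dim)
square_of_sum = embs.sum(dim=1).pow(2)       # (batch, dim)
sum_of_square = embs.pow(2).sum(dim=1)       # (batch, dim)
second_order = 0.5 * (square_of_sum - sum_of_square).sum(dim=1, keepdim=True)  # (batch, 1)
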
for epoch in range(1):
    loss = train()
    train_auc = evaluate(train_loader)
    val_auc = evaluate(val_loader)
    test_auc = evaluate(test_loader)
    print('Epoch: {:03d}, Loss: {:.5f}, Train Auc: {:.5f}, Val Auc: {:.5f}, Test Auc: {:.5f}'.
          format(epoch, loss, train_auc, val_auc, test_auc))
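The loop assumes that model, optimizer, crit, and the three loaders already exist; the data-loading code is not shown in these snippets. A minimal setup sketch (the learning rate and the BCE loss choice are assumptions consistent with a binary click target, not taken from the snippets):

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)                                    # Net as defined below
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)  # lr is an assumption
crit = torch.nn.BCELoss()  # assumes the model outputs probabilities
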
embed_dim = 128

import torch
import torch.nn.functional as F
from torch_geometric.nn import TopKPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # SAGEConv here is the custom MessagePassing layer defined below
        self.conv1 = SAGEConv(embed_dim, 128)
        self.pool1 = TopKPooling(128, ratio=0.8)
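The snippet cuts off after the first conv/pool pair; the full network stacks three such blocks and sums their pooled readouts before an MLP head. A self-contained one-block sketch of the pattern, using PyG's built-in SAGEConv in place of the custom layer (TinyNet and lin are illustrative names, not from the original post):

import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, TopKPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

class TinyNet(torch.nn.Module):
    def __init__(self, embed_dim=128):
        super().__init__()
        self.conv1 = SAGEConv(embed_dim, 128)   # PyG's built-in SAGEConv here
        self.pool1 = TopKPooling(128, ratio=0.8)
        self.lin = torch.nn.Linear(256, 1)      # 256 = max-pool + mean-pool concat

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
        x = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)  # graph-level readout
        return torch.sigmoid(self.lin(x)).squeeze(1)
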
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops

class SAGEConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(SAGEConv, self).__init__(aggr='max')  # "max" aggregation
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.act = torch.nn.ReLU()
        self.update_lin = torch.nn.Linear(in_channels + out_channels, in_channels, bias=False)
        self.update_act = torch.nn.ReLU()

    def forward(self, x, edge_index):
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        return self.act(self.lin(x_j))  # transform neighbours before max-aggregation

    def update(self, aggr_out, x):
        # combine the aggregated neighbourhood with the node's own features
        return self.update_act(self.update_lin(torch.cat([aggr_out, x], dim=1)))
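A quick shape check with toy tensors (the values are illustrative): note that update() projects back to in_channels, so the layer's output width equals its input width rather than out_channels.

conv = SAGEConv(16, 32)
x = torch.randn(4, 16)                             # 4 nodes, 16 features each
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # 3 directed edges
print(conv(x, edge_index).shape)                   # torch.Size([4, 16])
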
import numpy as np
from sklearn.metrics import roc_auc_score

def evaluate(loader):
    model.eval()
    predictions = []
    labels = []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            predictions.append(model(data).detach().cpu().numpy())
            labels.append(data.y.detach().cpu().numpy())
    # ROC-AUC over the whole loader
    return roc_auc_score(np.hstack(labels), np.hstack(predictions))
def train():
    model.train()
    loss_all = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        label = data.y.to(device)
        loss = crit(output, label)
        loss.backward()
        loss_all += data.num_graphs * loss.item()
        optimizer.step()
    # mean loss per graph over the epoch
    return loss_all / len(train_loader.dataset)
embed_dim = 128

import torch
import torch.nn.functional as F
from torch_geometric.nn import GraphConv, TopKPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # variant: PyG's built-in GraphConv with max aggregation
        # in place of the custom SAGEConv above
        self.conv1 = GraphConv(embed_dim, 128, aggr='max')
        self.pool1 = TopKPooling(128, ratio=0.8)
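A minimal smoke test for the built-in layer (toy tensors; the node count and edges are assumptions for illustration):

conv = GraphConv(128, 128, aggr='max')
x = torch.randn(5, 128)                                  # 5 nodes, 128 features
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])  # 4 directed edges
print(conv(x, edge_index).shape)                         # torch.Size([5, 128])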