Skip to content

Instantly share code, notes, and snippets.

@monk1337
Created August 22, 2021 13:07
Show Gist options
  • Save monk1337/04d1a27a079f0ed002ab22351c3a3587 to your computer and use it in GitHub Desktop.
Save monk1337/04d1a27a079f0ed002ab22351c3a3587 to your computer and use it in GitHub Desktop.
from transformers import BertTokenizer, BertModel
import torch
import numpy as np
# Module-level BERT tokenizer and encoder, created once at import time and
# shared as globals by GcnData.get_embeddings below.
# NOTE(review): from_pretrained presumably downloads/caches the
# 'bert-base-uncased' files on first use -- confirm for offline environments.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
class GcnData(object):
    """Build GCN inputs (node features and adjacency) from a label table.

    ``label_df`` is expected to be a one-hot encoded DataFrame: one column
    per label, one row per sample, with 1 marking that the label applies.
    """

    def __init__(self, label_df):
        """Store the one-hot label DataFrame used by the other methods."""
        self.df = label_df

    def get_embeddings(self, word):
        """Return the BERT ``pooler_output`` for ``word`` as a NumPy array.

        Uses the module-level ``tokenizer`` and ``model``.
        NOTE(review): for a single string the result is presumably of shape
        (1, hidden_size) -- confirm against the model configuration.
        """
        inputs = tokenizer(word, return_tensors='pt')
        # Inference only: the result is detached right away, so there is no
        # reason to record an autograd graph during the forward pass.
        with torch.no_grad():
            outputs = model(**inputs)
        return outputs.pooler_output.cpu().detach().numpy()

    def label_emb(self):
        """Return the GCN feature matrix.

        One row per column of ``self.df`` (in column order); each row is the
        squeezed embedding of that column's label as returned by
        ``get_embeddings``.
        """
        return np.array(
            [self.get_embeddings(label).squeeze() for label in self.df.columns]
        )

    def get_adj_matrix(self):
        """Return the GCN adjacency data for a one-hot label DataFrame.

        Returns:
            dict with two entries:
              * ``'adj'``: square NumPy array where entry (i, j) is the dot
                product of label columns i and j (their co-occurrence count
                for one-hot data), with the diagonal forced to zero.
              * ``'nums'``: NumPy array with, per label column, the number of
                rows equal to 1.
        """
        n_labels = self.df.shape[1]

        # Counting ``== 1`` instead of ``value_counts()[1]`` yields 0 for a
        # label that never occurs, where the lookup would raise KeyError.
        label_counts = (self.df == 1).sum(axis=0).to_numpy()

        # Boolean mask that is False on the diagonal, so multiplying by it
        # zeroes each label's co-occurrence with itself.
        off_diagonal = ~np.eye(n_labels, dtype=bool)
        adj_m = (self.df.T.dot(self.df) * off_diagonal).to_numpy()

        return {'adj': adj_m, 'nums': label_counts}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment