#!/usr/bin/env python
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 expandtab number
"""
glove的简易实现,包括
- 准备训练数据
- 训练词向量
- 实验词向量观察语义空间的降维效果
参考:
- 数据准备 https://gist.github.com/MatthieuBizien/de26a7a2663f00ca16d8d2558815e9a6
- 建模 https://github.com/kefirski/pytorch_GloVe/blob/master/GloVe/glove.py
- 绘图 https://nlpython.com/implementing-glove-model-with-pytorch/
Authors: qianweishuo<[email protected]>
Date: 2020/1/4 下午8:28
"""
import logging
import re
from collections import Counter
import sklearn.datasets
import torch
from functional import seq  # pip install pyfunctional; note the PyPI package name differs from the import name
from sklearn.decomposition import PCA
from torch import nn
from torch.nn import Parameter
from torch.nn.functional import mse_loss
from torch.optim import Adagrad
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm, trange
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s')
logger = logging.getLogger() # get root logger

class GloveDataset(Dataset):
    """ Tokenize an English corpus and organize it into a co-occurrence matrix for GloVe training """
    REGEX_WORD = re.compile(r"\b[a-zA-Z]{2,}\b")
    def __init__(self, docs, min_word_occurences=10, oov_token='<oov>', window_size=5):
        # numericalize
        docs_tok = [self.REGEX_WORD.findall(doc.lower()) for doc in docs]  # docs tokenized
        word_counter = {w: c for w, c in Counter(w for d in docs_tok for w in d).items()
                        if c > min_word_occurences}  # about twice as fast as using seq
        w2i = {oov_token: 0}
        docs_tok_id = [[w2i.setdefault(w, len(w2i)) if w in word_counter else 0
                        for w in doc] for doc in docs_tok]  # docs tokenized, in id
        self.w2i, self.i2w = w2i, seq(w2i.items()).order_by(lambda w_i: w_i[1]).smap(lambda w, i: w).to_list()
        self.n_words = len(w2i)  # note: not len(word_counter), which would miss the OOV entry and cause an index overflow
        # accumulate the co-occurrence matrix
        comatrix = Counter()
        for words_id in tqdm(docs_tok_id, desc='docs2comtx'):
            for i, w1 in enumerate(words_id):  # mind the window limit
                for j, w2 in enumerate(words_id[i + 1: i + 1 + window_size], start=i + 1):  # window_size right-neighbors
                    comatrix[(w1, w2)] += 1 / (j - i)
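        # Illustrative example (not from the original gist): for the id sequence [3, 7, 9]
        # with window_size=5, the loop above performs
        #   comatrix[(3, 7)] += 1/1;  comatrix[(3, 9)] += 1/2;  comatrix[(7, 9)] += 1/1
        # i.e. nearer neighbors contribute more, via the 1/(j - i) distance decay.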
        # extract training samples from the co-occurrence matrix:
        # (index of center word A, index of neighbor word B) -> "co-occurrence score" of A and B
        logger.info('extracting (a_word, b_words, co_score) from comatrix')
        a_words, b_words, co_score = zip(*((left, right, x) for (left, right), x in comatrix.items()))
        self.L_words = torch.LongTensor(a_words)
        self.R_words = torch.LongTensor(b_words)
        self.Y = torch.FloatTensor(co_score)
    def __len__(self):
        return len(self.Y)

    def __getitem__(self, item):
        return self.L_words[item], self.R_words[item], self.Y[item]
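
# Minimal usage sketch (hypothetical toy corpus, not part of the original gist):
#   ds = GloveDataset(["the quick fox jumps", "another short document"], min_word_occurences=0)
#   center_id, context_id, score = ds[0]  # one (A, B, co-occurrence score) training sample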

class GloVe(nn.Module):
    def __init__(self, embed_size=300, y_max=100, alpha=0.75):
        super().__init__()
        self.embed_size = embed_size
        self.y_max, self.alpha = y_max, alpha
        self.a_vecs: Parameter = None  # parameter shapes depend on the training data; initialized in build_model_from_dataset()
        self.b_vecs: Parameter = None
        self.a_bias: Parameter = None
        self.b_bias: Parameter = None
        self.i2w, self.w2i = [], dict()  # likewise, the vocabulary is actually assigned in build_model_from_dataset()
    def fit(self, dataset: GloveDataset, lr=0.05, batch_size=512, n_epochs=3):
        self.build_model_from_dataset(dataset)  # initialize the model parameters from the training dataset
        optimizer = Adagrad(self.parameters(), lr=lr)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)  # use num_workers=0 for easier debugging
        for epoch in trange(n_epochs, desc='epoch'):
            for batch_idx, (L, R, Y) in tqdm(enumerate(dataloader), desc='batch', total=len(dataset) // batch_size):
                loss = self(L.cuda(), R.cuda(), Y.cuda())
                if batch_idx % 100 == 0:
                    logger.info('epoch/batch %03d/%03d, loss = %6.3f', epoch + 1, batch_idx + 1, loss)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        logger.info('[DONE] training model')
    def build_model_from_dataset(self, dataset: GloveDataset, std=0.01):
        get_var_by_shape = lambda *shape: Parameter(torch.randn(*shape).mul(std).cuda(), requires_grad=True)
        # assigning Parameter attributes registers them with the module, so self.parameters() sees them
        self.a_vecs = get_var_by_shape(dataset.n_words, self.embed_size)
        self.b_vecs = get_var_by_shape(dataset.n_words, self.embed_size)
        self.a_bias = get_var_by_shape(dataset.n_words)
        self.b_bias = get_var_by_shape(dataset.n_words)
        self.i2w, self.w2i = dataset.i2w, dataset.w2i
    def forward(self, L, R, Y):
        W = Y.div(self.y_max).pow(self.alpha).clamp_max(1.0)  # sample weight from the co-occurrence score Y, i.e. f(X_{ij}) in the paper
        pred = torch.einsum('nd,nd->n', self.a_vecs[L], self.b_vecs[R]) + self.a_bias[L] + self.b_bias[R]
        target = (Y + 1).log()  # note the +1, which avoids log(0) blowing up
        return W @ mse_loss(pred, target, reduction='none')  # note reduction='none': the dot with W gives the weighted sum
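    # For reference, this computes the GloVe weighted least-squares objective
    # (Pennington et al., 2014) restricted to the current batch:
    #   J = sum_{ij} f(X_ij) * (a_i . b_j + bias_a_i + bias_b_j - log X_ij)^2
    #   f(x) = min((x / y_max)^alpha, 1)
    # The only deviation here is the log(X_ij + 1) target, per the comment above.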
    @property
    def embeddings(self):  # return the sum of the center vectors and the neighbor vectors; quick and dirty
        return self.a_vecs + self.b_vecs
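    # Minimal nearest-neighbor sketch (not in the original gist); `most_similar` and
    # `topn` are hypothetical names, and cosine similarity is one reasonable choice.
    def most_similar(self, word, topn=5):
        vecs = self.embeddings.detach()
        sims = torch.nn.functional.cosine_similarity(vecs[self.w2i[word]].unsqueeze(0), vecs)
        best = sims.topk(topn + 1).indices.tolist()  # +1 because the word itself ranks first
        return [(self.i2w[i], sims[i].item()) for i in best if i != self.w2i[word]][:topn]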
    def show_vec_space(self, n_show_vecs=300):
        # run PCA first, then t-SNE; applying t-SNE directly to the raw vectors is very slow,
        # even with metric='euclidean'
        embed_pca = PCA(n_components=4).fit_transform(self.embeddings[:n_show_vecs, :].cpu().detach().numpy())
        embed_tsne = TSNE(metric='euclidean', verbose=1, n_jobs=4).fit_transform(embed_pca)
        # in a Jupyter Notebook, remember to use %matplotlib inline
        fig, ax = plt.subplots(figsize=(20, 14))
        for idx in range(n_show_vecs):
            x, y = embed_tsne[idx, :]
            ax.scatter(x, y, color='steelblue')
            ax.annotate(self.i2w[idx], (x, y), alpha=0.7)

def main(n_docs=1000, n_epochs=3, batch_size=512, n_show_vecs=100):
    # import os; os.environ['https_proxy'] = 'ip:port'  # set a proxy here if needed
    logger.info("Fetching data")
    newsgroup = sklearn.datasets.fetch_20newsgroups(remove=('headers', 'footers', 'quotes'))
    logger.info("Building dataset")
    glove_data = GloveDataset(newsgroup.data[:n_docs])
    logger.info("training GloVe model")
    glove: GloVe = GloVe()
    glove.fit(glove_data, n_epochs=n_epochs, batch_size=batch_size)
    logger.info("showing 2D word vector space")
    glove.show_vec_space(n_show_vecs=n_show_vecs)

if __name__ == "__main__":
    # # this line fixes "RuntimeError: CUDA error: initialization error", but must not be called twice;
    # # see https://blog.csdn.net/lwc5411117/article/details/83272862
    # torch.multiprocessing.set_start_method('spawn')
    # also, if the Dataset/DataLoader never puts data on CUDA, the line above can be omitted; see http://xcx1024.com/ArtInfo/1772678.html
    main(n_docs=200)
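
# Environment note (an observation, not from the original gist): the script moves all
# parameters and batches to the GPU via .cuda(), so a CUDA-capable device is required;
# to run on CPU, drop the .cuda() calls in fit() and build_model_from_dataset().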