Measuring the similarity of books using TF-IDF, Doc2vec and TensorFlow
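The script below downloads twenty Project Gutenberg novels, ranks each book's tokens by TF-IDF to build a shared vocabulary from roughly the top 2,500 terms per book, and then trains a small TensorFlow model: the embeddings of an eight-token window are summed and fed to a softmax classifier that predicts which book the window came from. The columns of the classifier's weight matrix serve as 64-dimensional book vectors, which are compared by cosine similarity and projected to 2D with t-SNE, coloured by author. As a rough illustration of the IDF term used below (not part of the script itself), with this 20-book corpus:

import math
math.log(20 / (1 + 1))   # token found in a single book scores ~2.30
math.log(20 / (1 + 20))  # token found in every book scores ~-0.05, so common words never rank highly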
import collections
import math
import os
import pickle
import random
import re
import time
import urllib.request

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.manifold import TSNE
TOP_WORDS = 2500
EMBED_SIZE = 64
BATCH_SIZE = 256
EXAMPLE_SIZE = 3
WINDOW_SIZE = 8
DEMO_STEPS = 1000
STATUS_STEPS = 100
TOTAL_STEPS = 50000
ENCODING_FILE = 'book_encodings.pkl'
DATA_FOLDER = 'e:\\doc2vec'
TF_FOLDER = 'logs'
BOOKS = [
    ('https://www.gutenberg.org/files/46/46-0.txt', 'A Christmas Carol', 'Charles Dickens'),
    ('https://www.gutenberg.org/files/98/98-0.txt', 'A Tale of Two Cities', 'Charles Dickens'),
    ('https://www.gutenberg.org/files/11/11-0.txt', "Alice's Adventures in Wonderland", 'Lewis Carroll'),
    ('http://www.gutenberg.org/cache/epub/996/pg996.txt', 'Don Quixote', 'Miguel de Cervantes'),
    ('https://www.gutenberg.org/files/158/158-0.txt', 'Emma', 'Jane Austen'),
    ('http://www.gutenberg.org/cache/epub/83/pg83.txt', 'From the Earth to the Moon', 'Jules Verne'),
    ('https://www.gutenberg.org/files/1400/1400-0.txt', 'Great Expectations', 'Charles Dickens'),
    ('http://www.gutenberg.org/cache/epub/3748/pg3748.txt', 'Journey to the Center of the Earth', 'Jules Verne'),
    ('https://www.gutenberg.org/files/1342/1342-0.txt', 'Pride and Prejudice', 'Jane Austen'),
    ('http://www.gutenberg.org/cache/epub/21839/pg21839.txt', 'Sense and Sensibility', 'Jane Austen'),
    ('http://www.gutenberg.org/cache/epub/78/pg78.txt', 'Tarzan of the Apes', 'Edgar Rice Burroughs'),
    ('http://www.gutenberg.org/cache/epub/1013/pg1013.txt', 'The First Men In The Moon', 'H. G. Wells'),
    ('https://www.gutenberg.org/files/236/236-0.txt', 'The Jungle Book', 'Rudyard Kipling'),
    ('https://www.gutenberg.org/files/8147/8147-0.txt', 'The Man Who Would Be King', 'Rudyard Kipling'),
    ('https://www.gutenberg.org/files/35/35-0.txt', 'The Time Machine', 'H. G. Wells'),
    ('https://www.gutenberg.org/files/36/36-0.txt', 'The War of the Worlds', 'H. G. Wells'),
    ('http://www.gutenberg.org/cache/epub/43936/pg43936.txt', 'The Wonderful Wizard of Oz', 'L. Frank Baum'),
    ('https://www.gutenberg.org/files/12/12-0.txt', 'Through the Looking-Glass', 'Lewis Carroll'),
    ('https://www.gutenberg.org/files/120/120-0.txt', 'Treasure Island', 'Robert Louis Stevenson'),
    ('https://www.gutenberg.org/files/775/775-0.txt', 'When the Sleeper Wakes', 'H. G. Wells')
]
def books_containing(books, token):
    # Number of books whose token counter contains the token.
    return sum(1 for book in books if token in book)


def idf(books, token):
    # Inverse document frequency, with +1 smoothing on the document count.
    return math.log(len(books) / (1 + books_containing(books, token)))


def term_frequency(book, token):
    # Token count normalised by the number of distinct tokens in the book's counter.
    return book[token] / len(book)


def tf_idf(books, book, token):
    return term_frequency(book, token) * idf(books, token)
def download_books():
    path = os.path.join(DATA_FOLDER, 'books')
    if not os.path.exists(path):
        os.makedirs(path)
    for url, _, _ in BOOKS:
        name = url.rsplit('/', 1)[-1]
        filename = os.path.join(path, name)
        if not os.path.isfile(filename):
            print('Downloading', url)
            urllib.request.urlretrieve(url, filename)
def get_book_encodings():
    # Build authors dictionary.
    _, _, authors = zip(*BOOKS)
    author_set = set(authors)
    authors = {}
    for author in author_set:
        authors[author] = len(authors)
    # Load encodings from file if they exist.
    data_file = os.path.join(DATA_FOLDER, ENCODING_FILE)
    if os.path.isfile(data_file):
        with open(data_file, 'rb') as f:
            dump = pickle.load(f)
            return dump[0], dump[1], authors
    # Count tokens.
    word_count = 0
    book_tokens = []
    book_counters = []
    reg_alpha = re.compile('[^a-z]')
    reg_apostrophe = re.compile(r"['’]")
    path = os.path.join(DATA_FOLDER, 'books')
    for url, _, _ in BOOKS:
        file = url.rsplit('/', 1)[-1]
        with open(os.path.join(path, file), encoding='utf8') as book_file:
            text = book_file.read()
        tokens = reg_alpha.sub(' ', reg_apostrophe.sub('', text.lower())).split()
        tokens = [t for t in tokens if len(t) > 1]
        book_tokens.append(tokens)
        book_counters.append(collections.Counter(tokens))
        word_count += len(tokens)
    # Calculate TF-IDF scores.
    vocab_set = set()
    for book_ix, book in enumerate(book_counters):
        print('Top tokens in:', BOOKS[book_ix][1])
        scores = {token: tf_idf(book_counters, book, token) for token in book}
        sorted_tokens = sorted(scores.items(), key=lambda x: x[1], reverse=True)
        tokens, _ = zip(*sorted_tokens[:TOP_WORDS])
        vocab_set = vocab_set.union(set(tokens))
        for token, score in sorted_tokens[:EXAMPLE_SIZE]:
            print('>', token, round(score, 5))
        print()
    # Build tokens dictionary.
    vocab = {}
    for token in vocab_set:
        vocab[token] = len(vocab)
    print('Word count:', word_count)
    print('Vocab size:', len(vocab))
    print()
    # Encode books.
    book_encodings = []
    for book_ix, tokens in enumerate(book_tokens):
        book_labels = [0] * len(BOOKS)
        book_labels[book_ix] = 1
        encoding = []
        for token in tokens:
            if token in vocab:
                encoding.append(vocab[token])
        book_encodings.append([encoding, book_labels])
    with open(data_file, 'wb') as f:
        dump = [book_encodings, len(vocab)]
        pickle.dump(dump, f)
    return book_encodings, len(vocab), authors
def generate_batch(encodings):
    # Sample BATCH_SIZE windows of WINDOW_SIZE token ids, each labelled with its source book.
    tokens_batch = []
    books_batch = []
    while len(tokens_batch) < BATCH_SIZE:
        book_ix = random.randint(0, len(BOOKS) - 1)
        encoding = encodings[book_ix]
        token_ix = random.randint(0, len(encoding[0]) - WINDOW_SIZE)
        inputs = encoding[0][token_ix:token_ix + WINDOW_SIZE]
        tokens_batch.append(inputs)
        books_batch.append(encoding[1])
    return tokens_batch, books_batch
def report_closest(book_embeddings, book_ix, session):
    # Rank books by cosine similarity to the chosen book and print the closest ones.
    print('Closest to:', BOOKS[book_ix][1])
    norm = tf.sqrt(tf.reduce_sum(tf.square(book_embeddings), 1, keepdims=True))
    norm_book_embeddings = book_embeddings / norm
    sim = session.run(tf.matmul(tf.expand_dims(norm_book_embeddings[book_ix], 0),
                                norm_book_embeddings, transpose_b=True))[0]
    nearest = (-sim).argsort()[1:EXAMPLE_SIZE + 1]
    for k in range(EXAMPLE_SIZE):
        print('> %0.3f %s' % (sim[nearest[k]], BOOKS[nearest[k]][1]))
    print()
def main():
    # Setup environment.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    if not os.path.exists(DATA_FOLDER):
        os.makedirs(DATA_FOLDER)
    # Download books.
    download_books()
    # Build vocab and training data.
    book_encodings, vocab_size, authors = get_book_encodings()
    # Token embeddings.
    inputs = tf.placeholder(tf.int32, shape=[None, WINDOW_SIZE])
    token_embeddings = tf.Variable(tf.random_uniform([vocab_size, EMBED_SIZE], -1.0, 1.0))
    token_embedding = tf.zeros([BATCH_SIZE, EMBED_SIZE])
    for i in range(WINDOW_SIZE):
        token_embedding += tf.nn.embedding_lookup(token_embeddings, inputs[:, i])
    # Book embeddings.
    book_labels = tf.placeholder(tf.int32, shape=[None, len(BOOKS)])
    book_bias = tf.Variable(tf.constant(0.1, shape=[len(BOOKS)]))
    book_embeddings = tf.Variable(tf.random_uniform([EMBED_SIZE, len(BOOKS)], -1.0, 1.0))
    book_outputs = tf.matmul(token_embedding, book_embeddings) + book_bias
    # Calculate loss and train.
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=book_outputs, labels=book_labels))
    tf.summary.scalar('loss', loss)
    optimizer = tf.train.AdamOptimizer()
    global_step = tf.Variable(0, trainable=False, name='global_step')
    train_op = optimizer.minimize(loss=loss, global_step=global_step)
    saver = tf.train.Saver(max_to_keep=1)
    # Train model.
    with tf.Session() as session:
        # Initialise summary file writer.
        merged = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(os.path.join(DATA_FOLDER, TF_FOLDER), session.graph)
        # Load or initialise model.
        checkpoint = tf.train.latest_checkpoint(os.path.join(DATA_FOLDER, TF_FOLDER))
        if checkpoint:
            print('Loading checkpoint')
            saver.restore(session, checkpoint)
        else:
            session.run(tf.global_variables_initializer())
            session.run(tf.local_variables_initializer())
        step = 0
        start_time = time.time()
        while step < TOTAL_STEPS:
            tokens_batch, books_batch = generate_batch(book_encodings)
            feed_dict = {inputs: tokens_batch, book_labels: books_batch}
            _, step = session.run([train_op, global_step], feed_dict=feed_dict)
            # Update status.
            if step % STATUS_STEPS == 0:
                current_time = time.time()
                elapsed_time = current_time - start_time
                time_left = TOTAL_STEPS * elapsed_time / step - elapsed_time
                ls, summary = session.run([loss, merged], feed_dict=feed_dict)
                print('step %d, loss %0.3f, remaining %0.2f' % (step, ls, time_left / 60))
                summary_writer.add_summary(summary, step)
            # Run demo tasks.
            if step % DEMO_STEPS == 0:
                book_ix = int(np.random.choice(len(BOOKS), size=1))
                embeddings = session.run(tf.transpose(book_embeddings))
                report_closest(embeddings, book_ix, session)
                print('Saving checkpoint\n')
                saver.save(session, os.path.join(DATA_FOLDER, TF_FOLDER, 'doc2vec.ckpt'), global_step=step)
        # Book report.
        embeddings = session.run(tf.transpose(book_embeddings))
        for book_ix in range(len(BOOKS)):
            report_closest(embeddings, book_ix, session)
        # Project book embeddings onto 2D plane.
        embeddings = session.run(tf.transpose(book_embeddings))
        x, y = zip(*TSNE(n_components=2, verbose=1, perplexity=len(BOOKS) / len(authors), n_iter=5000).fit_transform(embeddings))
        colours = [authors[BOOKS[book_ix][2]] for book_ix in range(len(BOOKS))]
        fig, ax = plt.subplots()
        plt.scatter(x, y, s=120 ** 2, alpha=0.6, cmap='brg', c=colours)
        for book_ix in range(len(BOOKS)):
            ax.annotate(BOOKS[book_ix][1] + '\n' + BOOKS[book_ix][2],
                        (x[book_ix], y[book_ix]),
                        horizontalalignment='center', verticalalignment='center', fontsize=10)
        plt.show()
        print()


if __name__ == '__main__':
    main()
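Note: the graph-mode calls used here (tf.placeholder, tf.Session, tf.train.AdamOptimizer, tf.random_uniform) belong to the TensorFlow 1.x API and are no longer available at the top level of TensorFlow 2.x. A minimal sketch of the shim you would need at the top of the file if only TensorFlow 2.x is installed (an assumption about your environment, not part of the original gist):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # restore graph mode so placeholders and sessions work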