GloVe word embedding handling code for Python - improved fork of https://gist.github.com/ppope/0ff9fa359fb850ecf74d061f3072633a
import time
import io
import sys

from loguru import logger
import numpy
import tensorflow as tf

from ..polyfills.string import removeprefix, removesuffix
from .normalise_text import normalise as normalise_text
class GloVe:
    """
    Manages pre-trained GloVe word vectors.
    Ref https://www.damienpontifex.com/posts/using-pre-trained-glove-embeddings-in-tensorflow/
    Download pre-trained word vectors from here: https://nlp.stanford.edu/projects/glove/
    """

    def __init__(self, filepath):
        """
        Initialises a new GloVe class instance.
        filepath (string): The path to the file to load the pre-trained GloVe embeddings from.
        """
        super(GloVe, self).__init__()
        self.data = {}
        self.word_length = None
        self.filepath = filepath
        self.load()
    def load(self):
        """Loads the GloVe database from a given file."""
        print()
        start = time.time()
        handle = io.open(self.filepath, "r")
        for i, line in enumerate(handle):
            parts = line.split(" ", maxsplit=1)
            # We do NOT strip < and > here, because we do a lookup on that later.
            self.data[parts[0]] = list(map(
                lambda el: float(el),
                parts[1].split(" ")
            ))
            if self.word_length is None:
                self.word_length = len(self.data[parts[0]])

            # Update the CLI
            if i % 10000 == 0:
                sys.stderr.write(f"\rLoading GloVe from '{self.filepath}': {i}...")
        handle.close()
        sys.stderr.write(f" done in {round(time.time() - start, 3)}s.\n")
    def lookup(self, token: str):
        """Looks up the given token in the loaded embeddings."""
        key = token
        if key not in self.data:
            key = self.strip_outer(token)  # Try removing < and >
        if key not in self.data:
            key = f"<{token}>"  # Try wrapping in < and >
        if key not in self.data:
            return None  # Give up
        return self.data[key]  # We found it!

    def strip_outer(self, str: str) -> str:
        """Strips < and > from the given input string."""
        return removesuffix(removeprefix(str, "<"), ">")

    def _tokenise(self, str: str):
        """Splits the input string into tokens using Keras."""
        return tf.keras.preprocessing.text.text_to_word_sequence(
            self._normalise(str),
            filters=", \t\n",
            lower=True, split=" "
        )

    def _normalise(self, str):
        """Normalises input text to make it suitable for GloVe lookup."""
        return normalise_text(str)
    ###########################################################################

    def word_vector_length(self):
        """Returns the length of a single word vector."""
        return self.word_length

    def tweetvision(self, str):
        """
        Converts a string to a list of tokens as the AI will see it.
        Basically the same as .embeddings(str), but returns the tokens instead of the embeddings.
        """
        result = []
        for i, token in enumerate(self._tokenise(str)):
            if self.lookup(token) is None:
                continue
            result.append(token)
        return result

    def embeddings(self, str, length=-1):
        """
        Converts the given string to a list of word embeddings.
        str (string): The string to convert to an embedding.
        length (number): The number of tokens that the returned embedding should have. -1 (the default value) indicates that no length normalisation should be performed.
        """
        result = []
        # TODO: Handle out-of-vocabulary words better than just stripping them
        for i, token in enumerate(self._tokenise(str)):
            embedding = self.lookup(token)
            if embedding is None:
                logger.debug(f"Token '{token}' not found in the embeddings, skipping")
                continue
            result.append(embedding)

        # Normalise the embedding length if we're asked to
        if length > -1:
            result = result[-length:]
            shortfall = length - len(result)
            for _ in range(shortfall):
                result.append(numpy.zeros(self.word_vector_length()))
        return result
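For orientation, here is a minimal, hypothetical usage sketch of the class above (not part of the original gist). The embeddings filename is a placeholder for whichever pre-trained vectors file you extract from https://nlp.stanford.edu/projects/glove/, and the relative imports at the top of the class need adjusting to match your project layout before any of this will run.

# Hypothetical usage sketch; "glove.twitter.27B.50d.txt" is a placeholder path.
glove = GloVe("glove.twitter.27B.50d.txt")

# Tokens as the model will see them (out-of-vocabulary tokens are dropped)
print(glove.tweetvision("Hello @world, this is a #test :)"))

# Exactly 32 word vectors: truncated from the front, zero-padded at the end
vectors = glove.embeddings("Hello @world, this is a #test :)", length=32)
print(len(vectors), glove.word_vector_length())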
#!/usr/bin/env python3
import io
from pathlib import Path
import argparse
from loguru import logger
import json
# import tensorflow as tf

from lib.glove.glove import GloVe

"""
Lifted directly from my main codebase. Some things you'll need to do before you can use this:
 • The above imports will need to be adjusted (especially the GloVe one)
 • Download a pre-trained GloVe (https://nlp.stanford.edu/projects/glove/) and extract the ZIP
 • Install loguru for logging (or just remove all the calls to it) using pip ("sudo pip3 install loguru" on Linux; drop the "sudo" on Windows)
"""
def main():
    """Main entrypoint."""
    parser = argparse.ArgumentParser(description="This program calculates the longest word embedding in a list of tweets.")
    parser.add_argument("--glove", "-g", help="Filepath to the pretrained GloVe word vectors to load.")
    parser.add_argument("tweets_jsonl", help="The input tweets jsonl file to scan.")
    args = parser.parse_args()

    if not Path(args.tweets_jsonl).is_file():
        print("Error: File at '" + args.tweets_jsonl + "' does not exist.")
        exit(1)

    ###############################################################################

    glove = GloVe(args.glove)

    longest = 0
    handle = io.open(args.tweets_jsonl, "r")
    for i, line in enumerate(handle):
        obj = json.loads(line)
        result = glove.tweetvision(obj["text"])
        if len(result) > longest:
            longest = len(result)
            logger.info(f"\n\n\n\nTweet #{i} has a length of {longest}:")
            logger.info("INPUT:")
            logger.info(obj["text"])
            logger.info("\nOUTPUT:")
            logger.info(result)
    # print(tf.constant(glove.convert(args.input_string)))


if __name__ == "__main__":
    main()
else:
    print("This script must be run directly. It cannot be imported.")
    exit(1)
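For reference, the tweets file is read one JSON document per line, and only the text property is used, so a minimal (hypothetical) input line could look like this:

{"text": "Flooding reported on the high street this morning #weather"}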
""" | |
Script for preprocessing tweets by Romain Paulus. | |
Translation of Ruby script to create features for GloVe vectors for Twitter data. | |
with small modifications by Jeffrey Pennington | |
with translation to Python by Motoki Wu | |
updated to Python 3 by Phil Pope (@ppope) | |
with bugfixes and improvements by Starbeamrainbowlabs (@sbrl) | |
• Tidy up code with help of linters | |
• Add spaces surrounding punctation and <token_blocks> | |
• Limit runs of whitespace to a single space | |
• Transliterate ’ to ' | |
Ref original Ruby source http://nlp.stanford.edu/projects/glove/preprocess-twitter.rb | |
Ref https://gist.github.com/ppope/0ff9fa359fb850ecf74d061f3072633a | |
""" | |
import sys | |
import re | |
FLAGS = re.MULTILINE | re.DOTALL | |
# CHANGED 2021-11 after 2nd rerun: Handle ’ → ' and putting spaces before / after <allcaps>, [, and ] | |
def hashtag(text):
    """Handles hashtags, e.g. #MachineLearning → '<hashtag> Machine Learning' (lowercased later)."""
    text = text.group()
    hashtag_body = text[1:]
    if hashtag_body.isupper():
        result = " {} ".format(hashtag_body.lower())
    else:
        # Was re.split(r"(?=[A-Z])", hashtag_body, flags=FLAGS)
        # ref https://stackoverflow.com/a/2277363/1460422
        result = " ".join(["<hashtag>"] + re.findall("^[a-z]+|[A-Z][^A-Z]*", hashtag_body, flags=FLAGS))
    return result


def allcaps(text):
    """Handles all-caps text by lowercasing it and appending an <allcaps> marker."""
    text = text.group()
    return text.lower() + " <allcaps> "


# Convenience function to reduce repetition
def re_sub(pattern, repl, text):
    return re.sub(pattern, repl, text, flags=FLAGS)
def normalise(text):
    """
    Preprocesses the given input text to make it suitable for GloVe.
    This is the main function you want to import.
    :param str text: The text to normalise.
    """
    # Different regex parts for smiley faces
    eyes = r"[8:=;]"
    nose = r"['`\-]?"

    text = re_sub(r"https?:\/\/\S+\b|www\.(\w+\.)+\S*", " <url> ", text)
    text = re_sub(r"@\w+", " <user> ", text)
    text = re_sub(r"{}{}[)dD]+|[)dD]+{}{}".format(eyes, nose, nose, eyes), " <smile> ", text)
    text = re_sub(r"{}{}p+".format(eyes, nose), " <lolface> ", text)
    text = re_sub(r"{}{}\(+|\)+{}{}".format(eyes, nose, nose, eyes), " <sadface> ", text)
    text = re_sub(r"{}{}[\/|l*]".format(eyes, nose), " <neutralface> ", text)
    text = re_sub(r"/", " / ", text)
    text = re_sub(r"<3", " <heart> ", text)
    text = re_sub(r"[-+]?[.\d]*[\d]+[:,.\d]*", " <number> ", text)
    text = re_sub(r"#\S+", hashtag, text)
    text = re_sub(r"([!?.]){2,}", r" \1 <repeat> ", text)
    text = re_sub(r"\b(\S*?)(.)\2{2,}\b", r" \1\2 <elong> ", text)

    # -- I just don't understand why the Ruby script adds <allcaps> to everything, so I limited the selection.
    # text = re_sub(r"([^a-z0-9()<>'`\-]){2,}", allcaps)
    text = re_sub(r"([A-Z]){2,}", allcaps, text)

    # Added by @sbrl
    text = re_sub(r"’", "'", text)
    text = re_sub(r"([!?:;.,\[\]])", r" \1 ", text)  # Spaces around punctuation
    text = re_sub(r"\s+", r" ", text)  # Limit runs of whitespace to a single space

    return text.lower()
if __name__ == '__main__':
    _, text = sys.argv
    if text == "test":
        text = "I TEST alllll kinds of #hashtags and #HASHTAGS, @mentions and 3000 (http://t.co/dkfjkdf). w/ <3 :) haha!!!!!"
    tokens = normalise(text)
    print(tokens)
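As a final hypothetical illustration (not part of the original gist), the normalise function can also be imported and called directly. The module name below is an assumption based on the GloVe class importing from .normalise_text; the exact output string is not shown because it depends on the regex pipeline above.

# Hypothetical import; the actual module path depends on where this file lives in your project.
from normalise_text import normalise

# Mentions, URLs, numbers, hashtags and emoticons are rewritten into placeholder
# tokens such as <user>, <url>, <number>, <hashtag> and <smile>, then lowercased.
print(normalise("@alice check out #GloVe, 42 likes already!!! :) https://example.com"))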