-
-
Save newmedia2/f0912038d1046ce79a27f544074b0dce to your computer and use it in GitHub Desktop.
Scripts that were used in the "Video Games with Sense2Vec" tutorial found here: https://youtu.be/chLZ6g4t3VA.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This script combines two datasets to generate a file with all found patterns. | |
""" | |
import srsly | |
from prodigy.components.db import connect | |
import spacy | |
nlp = spacy.blank("en") | |
def text_to_patterns(text): | |
return [{"lower": t.text.lower()} for t in nlp(text)] | |
db = connect() | |
# These are the patterns from the custom approach | |
dataset = db.get_dataset("more-video-terms") | |
first_set = [{"label": "GAME", "pattern": text_to_patterns(d['text'])} for d in dataset if d['answer'] == 'accept'] | |
# These are the patterns from the sense2vec.teach recipe | |
dataset = db.get_dataset("video-game-terms") | |
second_set = [{"label": "GAME", "pattern": text_to_patterns(d['word'])} for d in dataset if d['answer'] == 'accept'] | |
srsly.write_jsonl("all-patterns.jsonl", first_set + second_set) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This script uses sense2vec to find potentially relevant phrases in your training data. | |
""" | |
import tqdm | |
import spacy | |
import typer | |
import srsly | |
from sklearn.metrics.pairwise import cosine_distances | |
from rich.console import Console | |
def scan_file(jsonl_file, s2v_path, queries, out_file):
    """Scan a .jsonl file for phrases sense2vec-similar to any query term.

    jsonl_file: input data; each line must carry a "text" key.
    s2v_path: path to a trained sense2vec model directory.
    queries: comma-separated seed terms to compare phrases against.
    out_file: destination .jsonl with {"text", "meta": {"dist"}} entries.
    """
    console = Console()
    # Initialize all components
    console.log("Loading spaCy model.")
    nlp = spacy.load("en_core_web_sm")
    console.log("Loading s2v model.")
    s2v_component = nlp.add_pipe("sense2vec")
    s2v_component.from_disk(s2v_path)
    # More setup
    console.log("Reading jsonl data.")
    blob = list(srsly.read_jsonl(jsonl_file))
    queries = queries.split(",")
    # Things to keep track of
    terms = set()
    not_terms = set()
    distances = []
    # (term, distance) pairs in acceptance order. The original zipped the
    # *set* `terms` against the *list* `distances`; set iteration order is
    # unrelated to append order, so each term got an arbitrary distance.
    # We record the pairing explicitly instead.
    accepted = []
    # Construct generator
    g = nlp.pipe((e["text"] for e in blob), batch_size=50)
    console.log("Starting big loop.")
    for doc in tqdm.tqdm(g, total=len(blob)):
        for phrase in doc._.s2v_phrases:
            for query in queries:
                new_terms, not_terms, distances = handle_phrase(
                    phrase, query, terms, not_terms, distances, s2v_component
                )
                # handle_phrase adds at most one new term per call, and its
                # distance is then the last element appended to `distances`.
                added = new_terms - terms
                if added:
                    accepted.append((added.pop(), distances[-1]))
                terms = new_terms
    # Write terms into file
    output = [{"text": t, "meta": {"dist": float(d)}} for t, d in accepted]
    srsly.write_jsonl(out_file, output)
    console.log(f"Results written to {out_file}.")
def handle_phrase(phrase, query, terms, not_terms, distances, s2v_component):
    """Decide whether `phrase` is close enough to `query` in sense2vec space.

    A phrase is accepted (added to `terms`, its distance appended to
    `distances`) when the cosine distance between its sense2vec vector and
    the best sense of `query` is below 0.6. Either way the phrase is added
    to `not_terms` so it is never re-evaluated on later calls.

    Returns the (possibly updated) `terms`, `not_terms`, `distances`.
    """
    text = phrase.text.lower()
    unseen = text not in terms
    unrejected = text not in not_terms
    # First check if phrase needs to be considered
    if unseen and unrejected:
        # Only resolve the query's best sense when it is actually needed;
        # the original computed it unconditionally for every phrase/query.
        best_key = s2v_component.s2v.get_best_sense(query)
        # Truncate long phrases to their first five tokens for the vector.
        vec = phrase[:5]._.s2v_vec
        if vec is not None:
            dist = cosine_distances([vec], [s2v_component.s2v[best_key]])[0, 0]
            if dist < 0.6:
                terms = terms | {text}
                distances.append(dist)
    if unrejected:
        # Mark as seen — accepted phrases too — so we skip them next time.
        # (Matches the original: acceptance also lands the text in not_terms.)
        not_terms = not_terms | {text}
    return terms, not_terms, distances
if __name__ == "__main__": | |
typer.run(scan_file) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This is an optional script that doesn't appear in the video, but it creates a subset | |
of the original data such that we only have examples where at least one pattern matches. | |
""" | |
import json | |
import spacy | |
from spacy.matcher import Matcher | |
nlp = spacy.blank("en") | |
matcher = Matcher(nlp.vocab) | |
import srsly | |
patterns = srsly.read_jsonl("all-patterns.jsonl") | |
for p in patterns: | |
matcher.add(p["label"], [p["pattern"]]) | |
texts = srsly.read_jsonl("xbox-support.jsonl") | |
for doc in nlp.pipe(t['text'] for t in texts): | |
matches = matcher(doc) | |
if matches: | |
print(json.dumps({"text": doc.text})) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment