Created
December 8, 2018 02:55
-
-
Save icoxfog417/1c9c2873cf4a5c2360dd3d342f396f99 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import spacy | |
class DependencyGraph():
    """Turn token-id sequences into dependency-arc adjacency matrices via spaCy.

    Each sequence is mapped back to words with ``vocabulary.inverse`` and run
    through a spaCy pipeline. For every token, ``matrix[token_index,
    head_index]`` is set to 1. spaCy documents the root token's head as the
    token itself, so the root row gets a self-loop.
    """

    def __init__(self, lang, vocabulary):
        """Load the spaCy pipeline used for parsing.

        Args:
            lang: Model name passed to ``spacy.load``.
            vocabulary: Object whose ``inverse(sequence)`` maps ids back to
                words -- presumably returns a list of str; confirm against
                the vocabulary implementation.
        """
        self.lang = lang
        # Only dependency arcs are used, so skip NER and text classification.
        self._parser = spacy.load(lang, disable=["ner", "textcat"])
        self.vocabulary = vocabulary

    @staticmethod
    def _arcs_to_matrix(arcs, size):
        """Return a (size, size) float matrix with 1 at every in-range (child, head) arc."""
        matrix = np.zeros((size, size))
        for child, head in arcs:
            # Arcs touching a position outside the (possibly truncated)
            # matrix are dropped rather than raising IndexError.
            if child < size and head < size:
                matrix[child, head] = 1
        return matrix

    def build(self, sequence, size=-1):
        """Return the dependency adjacency matrix for one id sequence.

        Args:
            sequence: Token ids; converted to words with ``vocabulary.inverse``.
            size: Side length of the square output. If <= 0, ``len(sequence)``
                is used. Arcs touching an index >= size are dropped, and rows/
                columns past the sentence length stay zero.

        Returns:
            A float numpy array of shape ``(size, size)`` where
            ``matrix[i, j] == 1`` means the head of token ``i`` is token ``j``.
        """
        # spaCy's Doc is only needed here; importing locally keeps the
        # file's top-level imports unchanged.
        from spacy.tokens import Doc

        words = list(self.vocabulary.inverse(sequence))
        _size = size if size > 0 else len(sequence)
        # Build the Doc directly from the given words instead of re-tokenizing
        # a " ".join()-ed string: spaCy's tokenizer can split words (e.g.
        # "don't" -> "do", "n't"), which would shift token.i away from the
        # sequence positions the matrix is indexed by. This also removes the
        # assumption that the language is whitespace-separated.
        doc = Doc(self._parser.vocab, words=words)
        for _, component in self._parser.pipeline:
            doc = component(doc)
        return self._arcs_to_matrix(((t.i, t.head.i) for t in doc), _size)

    def batch_build(self, sequences, size=-1):
        """Stack ``build`` results for several sequences into one array.

        Args:
            sequences: Iterable of id sequences.
            size: Side length used for every matrix. If <= 0, the length of
                the longest sequence is used so all matrices share one shape.

        Returns:
            A numpy array of shape ``(len(sequences), size, size)``; an empty
            1-D array when ``sequences`` is empty.
        """
        sequences = list(sequences)
        if size <= 0:
            # Without a common size each matrix would be len(s) x len(s) and
            # np.array would yield a ragged object array (an error on
            # NumPy >= 1.24).
            size = max((len(s) for s in sequences), default=0)
        return np.array([self.build(s, size) for s in sequences])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment