Last active
June 30, 2016 06:25
-
-
Save jnothman/20db5747729f9ee27a5a to your computer and use it in GitHub Desktop.
Reentrant Python wrapper to MaltParser using py4j server
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This loads the MaltParser [1] dependency parser as a Py4J [2] gateway server, and provides a Python module that translates input into the format MaltParser expects.
This keeps the MaltParser model loaded in memory between calls while still giving you a Python interface.
$ javac -cp maltparser-1.8.jar:py4j.jar MaltGateway.java
$ java -cp maltparser-1.8.jar:py4j.jar:. MaltGateway engmalt.linear-1.7 &
[1]
Loading model from engmalt.linear-1.7.mco
Gateway Server Started
$ python maltparser.py
1 I _ PRP PRP _ 2 nsubj _ _
2 eat _ VBP VBP _ 0 null _ _
3 broccoli _ NN NN _ 2 dobj _ _
4 . _ . . _ 2 punct _ _
Time to get some dependency parses...
[1] http://www.maltparser.org
[2] http://py4j.sourceforge.net
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import py4j.GatewayServer; | |
import java.io.File; | |
import java.net.URL; | |
import org.maltparser.concurrent.ConcurrentMaltParserModel; | |
import org.maltparser.concurrent.ConcurrentMaltParserService; | |
import org.maltparser.concurrent.graph.ConcurrentDependencyGraph; | |
import org.maltparser.concurrent.graph.ConcurrentDependencyNode; | |
public class MaltGateway { | |
ConcurrentDependencyGraph outputGraph = null; | |
ConcurrentMaltParserModel model = null; | |
public MaltGateway(File model_file) { | |
try { | |
URL url = model_file.toURI().toURL(); | |
model = ConcurrentMaltParserService.initializeParserModel(url); | |
} catch (Exception e) { | |
e.printStackTrace(); | |
} | |
} | |
private String addedColumnsString(ConcurrentDependencyGraph graph) { | |
// Extract columns 6 to 9 as a string | |
final StringBuilder sb = new StringBuilder(); | |
for (int i = 1; i < graph.nTokenNodes(); i++) { | |
ConcurrentDependencyNode node = graph.getTokenNode(i); | |
sb.append(node.getHeadIndex()); | |
for (int j = 7; j <= 9; j++) { | |
sb.append('\t'); | |
sb.append(node.getLabel(j)); | |
} | |
sb.append('\n'); | |
} | |
return sb.toString(); | |
} | |
public String[] parseMany(String[] sentences) { | |
// TODO: multithreading? | |
String[] out = new String[sentences.length]; | |
for (int i = 0; i < sentences.length; i++) | |
out[i] = addedColumnsString(parse(sentences[i].split("\n"))); | |
return out; | |
} | |
public ConcurrentDependencyGraph parse(String[] tokens) { | |
try { | |
return model.parse(tokens); | |
} catch (Exception e) { | |
e.printStackTrace(); | |
return null; | |
} | |
} | |
public static void main(String[] args) { | |
if (args.length < 1 || args.length > 1) { | |
System.out.println("Expected 1 arg: model path (without .mco)"); | |
System.exit(1); | |
} | |
System.out.println("Loading model from " + args[0] + ".mco"); | |
GatewayServer gatewayServer = new GatewayServer(new MaltGateway(new File(args[0] + ".mco"))); | |
gatewayServer.start(); | |
System.out.println("Gateway Server Started"); | |
} | |
} | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import namedtuple | |
from py4j.java_gateway import JavaGateway | |
# Fields of one input token: the six CoNLL-X input columns (ID..FEATS).
InTuple = namedtuple('InTuple', ['offset', 'form', 'lemma', 'cpostag', 'postag', 'feats'])
# An input token extended with the four dependency columns MaltParser fills in.
OutTuple = namedtuple('OutTuple', InTuple._fields + ('head', 'deprel', 'phead', 'pdeprel'))

# Connects, at import time and with Py4J's default connection settings, to the
# MaltGateway server started from Java.
gateway = JavaGateway()
# Shortcuts used to build java.lang.String[] arguments for gateway calls.
_new_array = gateway.new_array
_JString = gateway.jvm.java.lang.String
def parse(tokens):
    """Parse one sentence and return its tokens extended with dependency columns.

    :param tokens: iterable of tokens, each a sequence of the six CoNLL-X input
        fields (see ``InTuple``); every field is converted to text.
    :returns: list of ``OutTuple``, one per input token, in input order; an
        empty list for an empty sentence.
    :raises ValueError: if the gateway returns a different number of rows than
        there were tokens.
    """
    # Materialise once: the tokens are used both to build the request and to
    # build the result.
    tokens = [list(tok) for tok in tokens]
    if not tokens:
        return []
    # u'%s' % (x,) converts to text on both Python 2 (unicode) and Python 3 (str).
    sentence = u'\n'.join(u'\t'.join(u'%s' % (field,) for field in tok)
                          for tok in tokens)
    # parse_many yields only HEAD, DEPREL, PHEAD and PDEPREL per token, which is
    # exactly what OutTuple needs after the six input fields. Unpacking works
    # whether parse_many returns a list or an iterator.
    (deps,) = parse_many([sentence])
    rows = deps.rstrip(u'\n').split(u'\n')
    if len(rows) != len(tokens):
        raise ValueError('expected %d dependency rows, got %d'
                         % (len(tokens), len(rows)))
    return [OutTuple(*(tok + row.split(u'\t'))) for tok, row in zip(tokens, rows)]
def parse_strings(tokens):
    """Parse one sentence given as pre-formatted token lines.

    :param tokens: sized sequence of tab-separated token lines (text).
    :returns: the parsed graph rendered as text by the Java object's
        ``toString()``.
    :raises ValueError: if the gateway returned no graph.
    """
    arr = _new_array(_JString, len(tokens))
    arr[:] = tokens
    graph = gateway.entry_point.parse(arr)
    # Guard against a Java null (None here) instead of silently returning "None".
    if graph is None:
        raise ValueError('MaltParser failed to parse %r' % (tokens,))
    # u'%s' % (x,) converts to text on both Python 2 and Python 3.
    return u'%s' % (graph,)
def parse_many(sentences):
    """Parse several sentences in a single gateway call.

    :param sentences: sized sequence of sentences, each a text block holding
        one tab-separated token line per row.
    :returns: list with one text block per sentence. Each block has one line
        per token with the columns ``MaltGateway.parseMany`` emits (HEAD,
        DEPREL, PHEAD, PDEPREL), tab-separated.
    """
    arr = _new_array(_JString, len(sentences))
    arr[:] = sentences
    # Build a real list (not map()) so the result is indexable and reusable on
    # Python 3; u'%s' % (x,) converts to text on both Python 2 and Python 3.
    return [u'%s' % (deps,) for deps in gateway.entry_point.parseMany(arr)]
if __name__ == '__main__':
    # Demo: parse one example sentence and print each input row next to the
    # dependency columns returned for it.
    import textwrap

    sents = [
        textwrap.dedent('''
            1\tI\t_\tPRP\tPRP\t_
            2\teat\t_\tVBP\tVBP\t_
            3\tbroccoli\t_\tNN\tNN\t_
            4\t.\t_\t.\t.\t_
        ''').strip(),
    ]
    for sent, deps in zip(sents, parse_many(sents)):
        # sent is already stripped above; only the gateway output needs it.
        for l1, l2 in zip(sent.split('\n'), deps.strip().split('\n')):
            print(l1 + '\t' + l2)
        # print('') rather than print(): under Python 2 the latter prints "()".
        print('')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment