Created October 30, 2015 15:46
Save jacquerie/bdf91218c8561b23aa06 to your computer and use it in GitHub Desktop.
Python-driven GROBID retraining
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf8 -*- | |
import os | |
import grobid_core | |
import grobid_trainer | |
if __name__ == '__main__':
    # Regenerate GROBID's reference-segmentation training data, then retrain
    # the segmentation model. All paths are resolved relative to this script,
    # which expects a GROBID checkout with built one-jars under ./grobid.
    current_dir = os.path.dirname(os.path.realpath(__file__))
    grobid_dir = os.path.join(current_dir, 'grobid')
    # Version suffix of the locally built one-jar artifacts; kept in one place
    # so both jar paths stay in sync when GROBID is upgraded.
    grobid_version = '0.3.9-SNAPSHOT'

    grobid_home = os.path.join(grobid_dir, 'grobid-home')
    grobid_properties = os.path.join(grobid_home, 'config',
                                     'grobid.properties')
    grobid_input = os.path.join(grobid_dir, 'input')
    grobid_output = os.path.join(grobid_dir, 'output')
    classpath_core = os.path.join(
        grobid_dir, 'grobid-core', 'target',
        'grobid-core-%s.one-jar.jar' % grobid_version)
    classpath_trainer = os.path.join(
        grobid_dir, 'grobid-trainer', 'target',
        'grobid-trainer-%s.one-jar.jar' % grobid_version)

    # Bind the wrappers to names distinct from the imported modules so that
    # `grobid_core` / `grobid_trainer` keep referring to the modules.
    core = grobid_core.GROBIDCore(classpath_core, grobid_home,
                                  grobid_properties, grobid_input,
                                  grobid_output)
    status = core.create_training_reference_segmentation()
    # A truthy (non-zero) value is treated as a failed exit status: stop
    # rather than train on missing or partial data. A None result proceeds.
    if status:
        raise SystemExit(status)

    trainer = grobid_trainer.GROBIDTrainer(classpath_trainer, grobid_home)
    # Propagate the trainer's exit status (None exits with 0).
    raise SystemExit(trainer.train('segmentation'))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import subprocess32 | |
class GROBIDCore(object):
    """Wrapper class for GROBID core (batch mode).

    Every implemented method runs, as a blocking subprocess::

        java -Xmx<max_heap> -jar <classpath> -gH <grobid_home>
             -gP <grobid_properties> -dIn <grobid_input>
             -dOut <grobid_output> -exe <method>

    and returns the exit status reported by ``subprocess32.call``.
    """

    def __init__(self, classpath, grobid_home, grobid_properties,
                 grobid_input, grobid_output, max_heap='1024m'):
        """Store the paths passed to every GROBID invocation.

        Args:
            classpath: path to the grobid-core one-jar, run via ``java -jar``.
            grobid_home: directory passed as ``-gH``.
            grobid_properties: properties file passed as ``-gP``.
            grobid_input: input directory passed as ``-dIn``.
            grobid_output: output directory passed as ``-dOut``.
            max_heap: JVM maximum heap size, used as ``-Xmx<max_heap>``.
                Defaults to ``'1024m'``, the value previously hard-coded.
        """
        self.classpath = classpath
        self.grobid_home = grobid_home
        self.grobid_properties = grobid_properties
        self.grobid_input = grobid_input
        self.grobid_output = grobid_output
        self.max_heap = max_heap

    def _command(self, method):
        """Return the argv list that runs GROBID batch command *method*."""
        return ['java', '-Xmx%s' % self.max_heap,
                '-jar', self.classpath,
                '-gH', self.grobid_home,
                '-gP', self.grobid_properties,
                '-dIn', self.grobid_input,
                '-dOut', self.grobid_output,
                '-exe', method]

    def _call(self, method):
        """Run GROBID batch command *method* and return its exit status."""
        return subprocess32.call(self._command(method))

    def process_header(self):
        """Run the ``processHeader`` batch command; return the exit status."""
        return self._call('processHeader')

    def process_full_text(self):
        """Not implemented yet: currently a no-op that returns None."""
        pass

    def process_date(self):
        """Not implemented yet: currently a no-op that returns None."""
        pass

    def process_authors_header(self):
        """Not implemented yet: currently a no-op that returns None."""
        pass

    def process_authors_citation(self):
        """Not implemented yet: currently a no-op that returns None."""
        pass

    def process_affiliations(self):
        """Not implemented yet: currently a no-op that returns None."""
        pass

    def process_raw_reference(self):
        """Run the ``processRawReference`` batch command; return the exit status."""
        return self._call('processRawReference')

    def process_references(self):
        """Run the ``processReferences`` batch command; return the exit status."""
        return self._call('processReferences')

    def create_training_reference_segmentation(self):
        """Run ``createTrainingReferenceSegmentation``; return the exit status."""
        return self._call('createTrainingReferenceSegmentation')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
import subprocess32 | |
class GROBIDInvalidModel(RuntimeError):
    """Raised when a requested model name is not one GROBIDTrainer accepts."""
class GROBIDInvalidSplit(RuntimeError):
    """Raised when a train/eval split value falls outside [0, 1]."""
class GROBIDTrainer(object):
    """Wrapper class for calling GROBID trainer.

    Each public method validates its arguments and then runs, as a blocking
    subprocess::

        java -Xmx<max_heap> -jar <classpath> <mode> <model> -gH <grobid_home> [extra]

    returning the exit status reported by ``subprocess32.call``.
    """

    def __init__(self, classpath, grobid_home, max_heap='1024m'):
        """Store the trainer jar location and GROBID home directory.

        Args:
            classpath: path to the grobid-trainer one-jar, run via ``java -jar``.
            grobid_home: directory passed to the trainer as ``-gH``.
            max_heap: JVM maximum heap size, used as ``-Xmx<max_heap>``.
                Defaults to ``'1024m'``, the value previously hard-coded.
        """
        self.classpath = classpath
        self.grobid_home = grobid_home
        self.max_heap = max_heap
        # Positional mode argument handed to the trainer jar before the model.
        self.MODE = {'TRAIN': '0', 'EVAL': '1', 'TRAIN_AND_EVAL': '2'}
        # Model names this wrapper accepts; others are rejected before Java starts.
        self.MODELS = [
            'header',
            'segmentation'
        ]

    def _check_model(self, model):
        """Raise GROBIDInvalidModel unless *model* is in ``self.MODELS``."""
        if model not in self.MODELS:
            raise GROBIDInvalidModel(
                '%s is not a valid GROBID model.' % model
            )

    def _command(self, mode, model, *extra):
        """Build the argv list for one trainer invocation.

        Extra arguments are converted with ``str()`` because subprocess only
        accepts string arguments (e.g. the numeric split ratio).
        """
        command = ['java', '-Xmx%s' % self.max_heap, '-jar', self.classpath,
                   mode, model, '-gH', self.grobid_home]
        command.extend(str(arg) for arg in extra)
        return command

    def _run(self, mode, model, *extra):
        """Run one trainer invocation and return its exit status."""
        return subprocess32.call(self._command(mode, model, *extra))

    def train(self, model):
        """Wrapper for training a model.

        Raises:
            GROBIDInvalidModel: if *model* is not in ``self.MODELS``.

        Returns:
            The exit status of the Java process.
        """
        self._check_model(model)
        return self._run(self.MODE['TRAIN'], model)

    def eval(self, model):
        """Wrapper for evaluating a model.

        Raises:
            GROBIDInvalidModel: if *model* is not in ``self.MODELS``.

        Returns:
            The exit status of the Java process.
        """
        self._check_model(model)
        return self._run(self.MODE['EVAL'], model)

    def train_and_eval(self, model, split):
        """Wrapper for training and evaluating a model, given a split.

        Args:
            model: one of ``self.MODELS``.
            split: number in [0, 1], passed to the trainer as ``-s``.
                NOTE(review): presumably the fraction of data used for
                training — confirm against the GROBID trainer docs.

        Raises:
            GROBIDInvalidModel: if *model* is not in ``self.MODELS``.
            GROBIDInvalidSplit: if *split* is outside [0, 1].

        Returns:
            The exit status of the Java process.
        """
        self._check_model(model)
        if split < 0 or split > 1:
            raise GROBIDInvalidSplit(
                '%f must be between 0 and 1.' % split
            )
        # _command stringifies split; passing the raw float to subprocess
        # would raise TypeError.
        return self._run(self.MODE['TRAIN_AND_EVAL'], model, '-s', split)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment