Last active
June 23, 2019 08:35
-
-
Save simonepri/f150224b918b6145e3e1e02574c5c3c7 to your computer and use it in GitHub Desktop.
pbg-tools
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
_ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
import os
import random
import sys
from typing import Dict, Optional

import attr
from torchbiggraph.config import ConfigSchema
from torchbiggraph.converters.export_to_tsv import make_tsv
from torchbiggraph.converters.import_from_tsv import convert_input_data
from torchbiggraph.eval import do_eval
from torchbiggraph.filtered_eval import FilteredRankingEvaluator
from torchbiggraph.schema import DeepTypeError
from torchbiggraph.train import train
def random_split_file(
    filenames: Dict[str, str],
    from_line: int = 0,
    to_line: int = -1,
    train_split: float = 0.75,
    valid_split: float = 0.10,
    overwrite: bool = False
):
    """Shuffle an edgelist file and split it into train/valid/test files.

    Args:
        filenames: Mapping with keys 'edgelist' (input path) and 'train',
            'valid', 'test' (output paths).
        from_line: First line (0-based) of the edgelist to include.
        to_line: Line index (exclusive) at which to stop; -1 means read to
            the end of the file.
        train_split: Fraction of the selected lines written to the train file.
        valid_split: Fraction of the *remaining* (non-train) lines written to
            the valid file; whatever is left goes to the test file.
        overwrite: When False and all three output files already exist, the
            shuffle/split is skipped.

    Exits the process with status 1 when the edgelist file is missing.
    """
    edgelist_file = filenames['edgelist']
    train_file = filenames['train']
    valid_file = filenames['valid']
    test_file = filenames['test']
    output_paths = [train_file, valid_file, test_file]
    if not overwrite and all(os.path.exists(path) for path in output_paths):
        print('Found some files that indicate that the input data '
              'has already been shuffled and split, not doing it again.')
        print('These files are: %s' % ', '.join(output_paths))
        return
    if not os.path.exists(edgelist_file):
        # Fixed message typo ("does not exists") and removed the unreachable
        # `return` that followed sys.exit(1) in the original.
        print('The edgelist file does not exist')
        print('The path provided was: %s' % edgelist_file)
        sys.exit(1)
    print('Shuffling and splitting train/test file. This may take a while.')
    print('Reading data from file: %s' % edgelist_file)
    with open(edgelist_file, 'rt') as in_tf:
        lines = in_tf.readlines()
    if from_line != 0 or to_line != -1:
        # BUG FIX: the original sliced with `to_line` directly, so calling
        # with from_line != 0 and the default to_line == -1 silently dropped
        # the LAST line (Python interprets -1 as "up to the last element").
        # Map the -1 sentinel to "end of file" explicitly.
        end = None if to_line == -1 else to_line
        lines = lines[from_line:end]
    print('Shuffling data')
    random.shuffle(lines)
    train_split_len = int(len(lines) * train_split)
    # valid_split is a fraction of the lines NOT taken by the train split.
    valid_split_len = int((len(lines) - train_split_len) * valid_split)
    print('Splitting to train, validation and test files')
    with open(train_file, 'wt') as out_tf_train:
        out_tf_train.writelines(lines[:train_split_len])
    with open(valid_file, 'wt') as out_tf_valid:
        out_tf_valid.writelines(lines[train_split_len:train_split_len + valid_split_len])
    with open(test_file, 'wt') as out_tf_test:
        out_tf_test.writelines(lines[train_split_len + valid_split_len:])
    print('Total examples: %d' % len(lines))
    print('Train examples: %d' % train_split_len)
    print('Valid examples: %d' % valid_split_len)
    print('Test examples: %d' % (len(lines) - train_split_len - valid_split_len))
def convert_path(
    fname: str
) -> str:
    """Return the directory name holding the partitioned version of *fname*.

    The file extension is stripped and '_partitioned' is appended, e.g.
    'data/edges.tsv' -> 'data/edges_partitioned'.
    """
    root = os.path.splitext(fname)[0]
    return '%s_partitioned' % root
def run_training(config: ConfigSchema, edges_paths: Dict[str, str], filtered: bool = False):
    """Train a PBG model on the partitioned train split.

    NOTE(review): `filtered` is accepted but unused here — it appears to
    exist only for signature symmetry with `run_evaluation`.
    """
    partitioned_train = [convert_path(edges_paths['train'])]
    train(attr.evolve(config, edge_paths=partitioned_train))
def run_evaluation(
    config: ConfigSchema,
    edges_paths: Dict[str, str],
    filtered: bool = False,
    all_negs: bool = True
):
    """Evaluate a trained PBG model on the partitioned test split.

    When `all_negs` is set, every relation is evolved with all_negs=True and
    uniform negative sampling is disabled (num_uniform_negs=0). When
    `filtered` is set, a FilteredRankingEvaluator removes known positives
    from the test, valid and train splits from the candidate rankings.
    """
    partitioned_test = [convert_path(edges_paths['test'])]
    if all_negs:
        evolved_relations = [
            attr.evolve(relation, all_negs=all_negs)
            for relation in config.relations
        ]
        eval_config = attr.evolve(
            config,
            edge_paths=partitioned_test,
            relations=evolved_relations,
            num_uniform_negs=0,
        )
    else:
        eval_config = attr.evolve(config, edge_paths=partitioned_test)
    if not filtered:
        do_eval(eval_config)
        return
    filter_paths = [
        convert_path(edges_paths['test']),
        convert_path(edges_paths['valid']),
        convert_path(edges_paths['train']),
    ]
    do_eval(eval_config, evaluator=FilteredRankingEvaluator(eval_config, filter_paths))
def parse_config(
    config_dict: Dict
) -> ConfigSchema:
    """Build a ConfigSchema from a plain dict, aborting the process on error.

    On a schema mismatch the validation error is printed to stderr and the
    process exits with status 1.
    """
    try:
        return ConfigSchema.from_dict(config_dict)
    except DeepTypeError as err:
        print("Error in the configuration file, aborting.", file=sys.stderr)
        print(str(err), file=sys.stderr)
        sys.exit(1)
def input_from_tsv(
    config: ConfigSchema,
    edges_paths: Dict[str, str],
    cols: Dict[str, int],
    train_split: float = 0.95,
    valid_split: float = 0.45
):
    """Split the raw TSV edgelist and convert the splits to PBG's format.

    Args:
        config: Parsed PBG configuration (provides entities, relations and
            the entity output path).
        edges_paths: Mapping with keys 'edgelist', 'train', 'valid', 'test'.
        cols: Mapping with keys 'lhs', 'rel', 'rhs' giving the TSV column
            index of each edge component.
        train_split: Fraction of lines written to the train split
            (generalized from the previously hard-coded 0.95).
        valid_split: Fraction of the remaining lines written to the valid
            split (generalized from the previously hard-coded 0.45).
    """
    random_split_file(edges_paths, train_split=train_split, valid_split=valid_split)
    convert_input_data(
        config.entities,
        config.relations,
        config.entity_path,
        [edges_paths['test'], edges_paths['valid'], edges_paths['train']],
        lhs_col=cols['lhs'],
        rel_col=cols['rel'],
        rhs_col=cols['rhs'],
    )
def output_to_tsv(
    config: ConfigSchema,
    embs_paths: Dict[str, str]
):
    """Export the trained entity and relation embeddings to TSV files.

    Reads 'dictionary.json' from the configured entity path and writes the
    embeddings from the checkpoint to embs_paths['ent'] / embs_paths['rel'].
    """
    dictionary_path = os.path.join(config.entity_path, 'dictionary.json')
    with open(dictionary_path, "rt") as fin:
        dictionary = json.load(fin)
    with open(embs_paths['ent'], "wt+") as ent_out, open(embs_paths['rel'], "wt+") as rel_out:
        make_tsv(config.checkpoint_path, dictionary["relations"], dictionary["entities"], ent_out, rel_out)
def run_pbg(
    config_dict: Dict,
    edges_paths: Dict[str, str],
    embs_paths: Dict[str, str],
    run_train: bool = True,
    run_eval: bool = True,
    filtered: bool = False,
    all_negs: bool = True,
    cols: Optional[Dict[str, int]] = None
):
    """End-to-end PBG pipeline: convert input, train, evaluate, export.

    Args:
        config_dict: Raw configuration dict, validated via `parse_config`.
        edges_paths: Mapping with keys 'edgelist', 'train', 'valid', 'test'.
        embs_paths: Mapping with keys 'ent' and 'rel' for the TSV outputs.
        run_train: Whether to run the training phase.
        run_eval: Whether to run the evaluation phase.
        filtered: Forwarded to `run_evaluation` (filtered ranking).
        all_negs: Forwarded to `run_evaluation` (use all negatives).
        cols: TSV column indices for 'lhs', 'rel', 'rhs'; defaults to
            {'lhs': 0, 'rel': 1, 'rhs': 2}.
    """
    # BUG FIX: the original used a mutable dict as the default for `cols`,
    # which is shared across all calls; use a None sentinel instead.
    if cols is None:
        cols = {'lhs': 0, 'rel': 1, 'rhs': 2}
    config = parse_config(config_dict)
    input_from_tsv(config, edges_paths, cols)
    if run_train:
        run_training(config, edges_paths)
    if run_eval:
        run_evaluation(config, edges_paths, filtered=filtered, all_negs=all_negs)
    output_to_tsv(config, embs_paths)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment