This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Collect one edge-list path per FILENAMES entry (both the train and the test
# file), located under DATA_DIR.
# NOTE(review): DATA_DIR, CONFIG_PATH and FILENAMES come from other snippets.
| edge_paths = [os.path.join(DATA_DIR, name) for name in FILENAMES.values()] | |
# Presumably converts the raw edge lists into PyTorch-BigGraph's input format
# as described by the config at CONFIG_PATH — confirm against the PBG docs.
| from torchbiggraph.converters.import_from_tsv import convert_input_data | |
| convert_input_data( | |
| CONFIG_PATH, | |
| edge_paths, | |
# lhs_col / rhs_col select which column of each edge line is read as the
# left-hand and right-hand entity.
| lhs_col=0, | |
| rhs_col=1, | |
# No relation column; presumably because the graph has a single relation type.
| rel_col=None, |
# NOTE(review): this excerpt ends before the call's closing parenthesis; any
# remaining arguments are not visible here.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Train the model on the training split only: load the on-disk config, then
# point its edge paths at the (converted) training edge list before training.
import attr
from torchbiggraph.config import parse_config
from torchbiggraph.train import train

base_config = parse_config(CONFIG_PATH)
# convert_path is defined in another snippet; presumably it maps the raw
# edge-list file to its converted location — confirm.
train_path = [
    convert_path(os.path.join(DATA_DIR, FILENAMES['train'])),
]
train_config = attr.evolve(base_config, edge_paths=train_path)

train(train_config)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Evaluate on the held-out test split, reusing the training config with only
# its edge paths swapped for the test edge list.
from torchbiggraph.eval import do_eval

test_edge_file = os.path.join(DATA_DIR, FILENAMES['test'])
eval_path = [convert_path(test_edge_file)]
eval_config = attr.evolve(train_config, edge_paths=eval_path)
do_eval(eval_config)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load the entity dictionary (presumably written by the conversion step) and
# find one user's offset; the same offset is used to index that user's row in
# the embeddings table.
import json
import os

import h5py  # not used here; the snippet that reads the embedding file needs it

with open(os.path.join(DATA_DIR, "dictionary.json"), "rt", encoding="utf-8") as tf:
    dictionary = json.load(tf)

user_id = "0"
# list.index raises ValueError if user_id is not among the known entities.
offset = dictionary["entities"]["user_id"].index(user_id)
# A single f-string avoids print()'s separator adding doubled spaces between
# arguments that already carry their own spaces.
print(f"our offset for user_id {user_id} is: {offset}")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
adapted from https://github.com/facebookresearch/PyTorch-BigGraph/blob/master/torchbiggraph/examples/livejournal.py
"""
# The attribution string above must be the first statement to act as the
# module docstring; placed after the imports it was only a no-op expression.
import os
import random  # NOTE(review): unused in the visible snippets — confirm it is needed

# Edge-list file name for each split, joined onto DATA_DIR by the other
# snippets. Insertion order (train, then test) is the order of the edge paths
# built from FILENAMES.values().
FILENAMES = {
    'train': 'train.txt',
    'test': 'test.txt',
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
print("Now let's do some simple things within torch:")
# NOTE(review): DotComparator is not used in this snippet; presumably it builds
# `comparator` in a snippet not shown here.
from torchbiggraph.model import DotComparator

# Offsets of the users of interest in the entity dictionary. Each offset is
# also the row of that user's vector in the embeddings table read below.
user_ids = dictionary["entities"]["user_id"]
src_entity_offset = user_ids.index("0")  # source user "0"
dest_1_entity_offset = user_ids.index("7")  # candidate user "7"
dest_2_entity_offset = user_ids.index("1")  # candidate user "1"
rel_type_index = dictionary["relations"].index("follow")  # note we only have one...

# The file name presumably encodes entity type "user_id", partition 0 and
# checkpoint version 10 — adjust if training produced a different version.
with h5py.File("model/example_2/embeddings_user_id_0.v10.h5", "r") as hf:
    src_embedding = hf["embeddings"][src_entity_offset, :]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
print("finally, let's do some ranking...")
import torch

# Score the source user against every candidate in one comparator call.
# NOTE(review): `comparator` and `dest_embeddings` come from snippets not shown;
# dest_embeddings is assumed to hold exactly entity_count rows — confirm.
entity_count = 8
# Embedding width taken from the source vector (one row of the embeddings
# table) rather than repeating a literal, so every reshape below stays in sync
# with the trained model and with entity_count.
embedding_dim = src_embedding.shape[-1]
scores, _, _ = comparator(
    comparator.prepare(
        torch.tensor(src_embedding.reshape([1, 1, embedding_dim]))
    ).expand(1, entity_count, embedding_dim),
    comparator.prepare(
        torch.tensor(dest_embeddings.reshape([1, entity_count, embedding_dim]))
    ),
    torch.empty(1, 0, embedding_dim),  # Left-hand side negatives, not needed
    torch.empty(1, 0, embedding_dim),  # Right-hand side negatives, not needed
)
# Candidate indices ordered from highest to lowest score.
permutation = scores.flatten().argsort(descending=True)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Container definition (presumably a batect config); nesting restored after the
# page's table markup flattened the indentation.
containers:
  build-env:
    image: python:3.7
    working_directory: /src
    environment:
      # Makes modules under /src importable by their top-level names.
      PYTHONPATH: "/src"
    # Run the container as the invoking user, with this as its home directory.
    run_as_current_user:
      enabled: true
      home_directory: /home/container-user
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Container definition with host mounts (presumably a batect config); nesting
# restored after the page's table markup flattened the indentation.
containers:
  build-env:
    image: python:3.7
    volumes:
      # Project directory mounted at /src.
      - local: .
        container: /src
        options: cached
      # Presumably a host-side pip cache kept between runs — confirm.
      - local: .pip-cache
        container: /src/.pip-cache
        options: cached
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Container definition with host mounts (presumably a batect config); nesting
# restored after the page's table markup flattened the indentation.
# NOTE(review): identical to the preceding volumes excerpt on this page; the
# two files presumably diverge after the visible lines.
containers:
  build-env:
    image: python:3.7
    volumes:
      # Project directory mounted at /src.
      - local: .
        container: /src
        options: cached
      # Presumably a host-side pip cache kept between runs — confirm.
      - local: .pip-cache
        container: /src/.pip-cache
        options: cached