forked from facebookresearch/PyTorch-BigGraph
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfb15k.py
101 lines (86 loc) · 3.31 KB
/
fb15k.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE.txt file in the root directory of this source tree.
import argparse
from pathlib import Path
import attr
import pkg_resources
from torchbiggraph.config import ConfigFileLoader, add_to_sys_path
from torchbiggraph.converters.importers import TSVEdgelistReader, convert_input_data
from torchbiggraph.converters.utils import download_url, extract_tar
from torchbiggraph.eval import do_eval
from torchbiggraph.filtered_eval import FilteredRankingEvaluator
from torchbiggraph.train import train
from torchbiggraph.util import (
SubprocessInitializer,
set_logging_verbosity,
setup_logging,
)
# Public download location of the FB15k dataset archive.
FB15K_URL = "https://dl.fbaipublicfiles.com/starspace/fb15k.tgz"
# Edge-list files inside the extracted archive, relative to the data directory.
# Order matters: train, valid, test (main() pairs them with the config's
# edge_paths in this same order).
FILENAMES = [
    "FB15k/freebase_mtr100_mte100-{}.txt".format(split)
    for split in ("train", "valid", "test")
]
# Filesystem path of the sample FB15k config shipped inside the
# torchbiggraph.examples package, resolved wherever the package manager put it.
# Users can point at a different file with --config.
DEFAULT_CONFIG = pkg_resources.resource_filename(
    "torchbiggraph.examples",
    "configs/fb15k_config.py",
)
def main():
    """Run the FB15k example end to end: convert, train, then evaluate.

    Steps:
      1. Parse command-line options and load the config given by ``--config``,
         applying any ``-p/--param`` overrides.
      2. Check that the config lists one output edge path per input file in
         ``FILENAMES`` (train, valid, test). This happens before any download
         so that a bad config fails fast.
      3. Download and extract the FB15k archive into ``--data_dir``.
      4. Convert the three TSV edge lists into the entity/edge files named by
         the config.
      5. Train on the train split only.
      6. Evaluate on the test split: filtered by default, or unfiltered with
         ``--no-filtered``.

    Raises:
        ValueError: if ``config.edge_paths`` does not have exactly
            ``len(FILENAMES)`` entries.
    """
    setup_logging()
    parser = argparse.ArgumentParser(description="Example on FB15k")
    parser.add_argument("--config", default=DEFAULT_CONFIG, help="Path to config file")
    # Each occurrence of -p/--param appends its own list of values.
    parser.add_argument("-p", "--param", action="append", nargs="*")
    parser.add_argument(
        "--data_dir", type=Path, default="data", help="where to save processed data"
    )
    parser.add_argument(
        "--no-filtered",
        dest="filtered",
        action="store_false",
        help="Run unfiltered eval",
    )
    args = parser.parse_args()

    # Load and validate the config before touching the network, so a typo in
    # --config / --param is reported without first downloading the dataset.
    loader = ConfigFileLoader()
    config = loader.load_config(args.config, args.param)
    set_logging_verbosity(config.verbose)

    # The output edge paths are matched positionally with FILENAMES
    # (train, valid, test). Fail with a clear message instead of an opaque
    # tuple-unpacking error.
    if len(config.edge_paths) != len(FILENAMES):
        raise ValueError(
            "Expected the config to list {} edge paths (train, valid, test), "
            "got {}".format(len(FILENAMES), len(config.edge_paths))
        )
    output_train_path, output_valid_path, output_test_path = config.edge_paths

    # Initializers presumably replayed in worker processes started by
    # train/do_eval. The config's directory is added to sys.path there as
    # well — TODO confirm against SubprocessInitializer.
    subprocess_init = SubprocessInitializer()
    subprocess_init.register(setup_logging, config.verbose)
    subprocess_init.register(add_to_sys_path, loader.config_dir.name)

    # Download data.
    data_dir = args.data_dir
    fpath = download_url(FB15K_URL, data_dir)
    extract_tar(fpath)
    print("Downloaded and extracted file.")

    # Convert the raw edge lists. Columns are lhs entity (0), relation (1),
    # and rhs entity (2).
    input_edge_paths = [data_dir / name for name in FILENAMES]
    convert_input_data(
        config.entities,
        config.relations,
        config.entity_path,
        config.edge_paths,
        input_edge_paths,
        TSVEdgelistReader(lhs_col=0, rhs_col=2, rel_col=1),
        dynamic_relations=config.dynamic_relations,
    )

    # Train on the train split only.
    train_config = attr.evolve(config, edge_paths=[output_train_path])
    train(train_config, subprocess_init=subprocess_init)

    # Evaluate on the test split with all_negs=True on every relation and no
    # uniform negatives (presumably ranking against all candidate entities).
    relations = [attr.evolve(r, all_negs=True) for r in config.relations]
    eval_config = attr.evolve(
        config, edge_paths=[output_test_path], relations=relations, num_uniform_negs=0
    )
    if args.filtered:
        # Filtered evaluation is given every split's edges, presumably so
        # known-true edges are not counted as negatives when ranking.
        filter_paths = [output_test_path, output_valid_path, output_train_path]
        do_eval(
            eval_config,
            evaluator=FilteredRankingEvaluator(eval_config, filter_paths),
            subprocess_init=subprocess_init,
        )
    else:
        do_eval(eval_config, subprocess_init=subprocess_init)


if __name__ == "__main__":
    main()