forked from lavis-nlp/spert

Commit 9968fae (1 parent: 35c6a80)
initial code commit and updated README

Showing 23 changed files with 3,246 additions and 2 deletions.
.gitignore
@@ -0,0 +1,113 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
migrations

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# custom
data/
tmp/
.idea/
runs/
*.state
checkpoint
README.md
@@ -1,4 +1,48 @@
 # SpERT: Span-based Entity and Relation Transformer
-PyTorch code for SpERT: Span-based Entity and Relation Transformer
+PyTorch code for "SpERT: Span-based Entity and Relation Transformer"
 
-Work in progress (ETA by mid November)
+## Setup
+### Requirements
+- Required
+  - Python 3.5+
+  - PyTorch 1.1.0+ (tested with version 1.3.1)
+  - transformers 2.2.0+ (tested with version 2.2.0)
+  - scikit-learn (tested with version 0.21.3)
+  - tqdm (tested with version 4.19.5)
+  - numpy (tested with version 1.17.4)
+- Optional
+  - jinja2 (tested with version 2.10) - if installed, used to export relation extraction examples
+  - tensorboardX (tested with version 1.6) - if installed, used to save the training process to TensorBoard
+
+### Fetch data
+Fetch the datasets (CoNLL04 \[1\], SciERC \[2\] and ADE \[3\]), converted to a common JSON format:
+```
+bash ./scripts/fetch_datasets.sh
+```
+
+Fetch the model checkpoints (best out of 5 runs for each dataset):
+```
+bash ./scripts/fetch_models.sh
+```
+
+## Examples
+Evaluate the CoNLL04 model on the test dataset:
+```
+python ./spert.py eval --config configs/example_eval.conf
+```
+
+Train on the CoNLL04 train dataset and evaluate on the dev dataset:
+```
+python ./spert.py train --config configs/example_train.conf
+```
+
+## Notes
+- To train SpERT with SciBERT, download SciBERT from https://github.com/allenai/scibert (under "PyTorch HuggingFace Models") and set "model_path" and "tokenizer_path" to point to the SciBERT directory.
+- Call "python ./spert.py train --help" or "python ./spert.py eval --help" for a description of the training/evaluation arguments.
+
+## References
+```
+[1] Dan Roth and Wen-tau Yih, ‘A Linear Programming Formulation for Global Inference in Natural Language Tasks’, in Proc. of CoNLL 2004 at HLT-NAACL 2004, pp. 1–8, Boston, Massachusetts, USA, (May 6 - May 7 2004). ACL.
+[2] Yi Luan, Luheng He, Mari Ostendorf, and Hannaneh Hajishirzi, ‘Multi-Task Identification of Entities, Relations, and Coreference for Scientific Knowledge Graph Construction’, in Proc. of EMNLP 2018, pp. 3219–3232, Brussels, Belgium, (October-November 2018). ACL.
+[3] Harsha Gurulingappa, Abdul Mateen Rajput, Angus Roberts, Juliane Fluck, Martin Hofmann-Apitius, and Luca Toldo, ‘Development of a Benchmark Corpus to Support the Automatic Extraction of Drug-related Adverse Effects from Medical Case Reports’, J. of Biomedical Informatics, 45(5), 885–892, (October 2012).
+```
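
A hypothetical sketch (not part of the README) for a first look at one of the fetched JSON datasets; the path comes from the example configs below, and the assumption that the file is a JSON array of documents should be checked against what fetch_datasets.sh actually downloads:
```
# Hypothetical sketch: a first look at a fetched dataset file.
# Assumption: the file is a JSON array of annotated sentence documents.
import json

with open('data/datasets/conll04/conll04_dev.json') as f:
    docs = json.load(f)

print(len(docs))       # number of documents
print(docs[0].keys())  # inspect the per-document fields
```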
Empty file.
args.py
@@ -0,0 +1,90 @@
import argparse


def _add_common_args(arg_parser):
    arg_parser.add_argument('--config', type=str)

    # Input
    arg_parser.add_argument('--types_path', type=str, help="Path to type specifications")

    # Preprocessing
    arg_parser.add_argument('--tokenizer_path', type=str, help="Path to tokenizer")
    arg_parser.add_argument('--max_span_size', type=int, default=10, help="Maximum size of spans")
    arg_parser.add_argument('--lowercase', action='store_true', default=False,
                            help="If true, input is lowercased during preprocessing")
    arg_parser.add_argument('--sampling_processes', type=int, default=4,
                            help="Number of sampling processes. 0 = no multiprocessing for sampling")
    arg_parser.add_argument('--sampling_limit', type=int, default=100, help="Maximum number of sample batches in queue")

    # Logging
    arg_parser.add_argument('--label', type=str, help="Label of run. Used as the directory name of logs/models")
    arg_parser.add_argument('--log_path', type=str, help="Path to directory where training/evaluation logs are stored")
    arg_parser.add_argument('--store_examples', action='store_true',
                            help="If true, store evaluation examples on disk (in log directory)")
    arg_parser.add_argument('--example_count', type=int, default=None,
                            help="Count of evaluation examples to store (if store_examples == True)")
    arg_parser.add_argument('--debug', action='store_true', default=False, help="Debugging mode on/off")

    # Model / Training / Evaluation
    arg_parser.add_argument('--model_path', type=str, help="Path to directory that contains model checkpoints")
    arg_parser.add_argument('--model_type', type=str, default="spert", help="Type of model")
    arg_parser.add_argument('--cpu', action='store_true', default=False,
                            help="If true, train/evaluate on CPU even if a CUDA device is available")
    arg_parser.add_argument('--eval_batch_size', type=int, default=1, help="Evaluation batch size")
    arg_parser.add_argument('--max_pairs', type=int, default=1000,
                            help="Maximum entity pairs to process during training/evaluation")
    arg_parser.add_argument('--rel_filter_threshold', type=float, default=0.4, help="Filter threshold for relations")
    arg_parser.add_argument('--size_embedding', type=int, default=25, help="Dimensionality of size embedding")
    arg_parser.add_argument('--prop_drop', type=float, default=0.1, help="Probability of dropout used in SpERT")
    arg_parser.add_argument('--freeze_transformer', action='store_true', default=False, help="Freeze BERT weights")

    # Misc
    arg_parser.add_argument('--seed', type=int, default=None, help="Seed")
    arg_parser.add_argument('--cache_path', type=str, default=None,
                            help="Path to cache transformer models (for HuggingFace transformers library)")


def train_argparser():
    arg_parser = argparse.ArgumentParser()

    # Input
    arg_parser.add_argument('--train_path', type=str, help="Path to train dataset")
    arg_parser.add_argument('--valid_path', type=str, help="Path to validation dataset")

    # Logging
    arg_parser.add_argument('--save_path', type=str, help="Path to directory where model checkpoints are stored")
    arg_parser.add_argument('--init_eval', action='store_true', default=False,
                            help="If true, evaluate validation set before training")
    arg_parser.add_argument('--save_optimizer', action='store_true', default=False,
                            help="Save optimizer alongside model")
    arg_parser.add_argument('--train_log_iter', type=int, default=1, help="Log training process every x iterations")
    arg_parser.add_argument('--final_eval', action='store_true', default=False,
                            help="Evaluate the model only after training, not at every epoch")

    # Model / Training
    arg_parser.add_argument('--train_batch_size', type=int, default=2, help="Training batch size")
    arg_parser.add_argument('--epochs', type=int, default=20, help="Number of epochs")
    arg_parser.add_argument('--neg_entity_count', type=int, default=100,
                            help="Number of negative entity samples per document (sentence)")
    arg_parser.add_argument('--neg_relation_count', type=int, default=100,
                            help="Number of negative relation samples per document (sentence)")
    arg_parser.add_argument('--lr', type=float, default=5e-5, help="Learning rate")
    arg_parser.add_argument('--lr_warmup', type=float, default=0.1,
                            help="Proportion of total train iterations to warm up in linear increase/decrease schedule")
    arg_parser.add_argument('--weight_decay', type=float, default=0.01, help="Weight decay to apply")
    arg_parser.add_argument('--max_grad_norm', type=float, default=1.0, help="Maximum gradient norm")

    _add_common_args(arg_parser)

    return arg_parser


def eval_argparser():
    arg_parser = argparse.ArgumentParser()

    # Input
    arg_parser.add_argument('--dataset_path', type=str, help="Path to dataset")

    _add_common_args(arg_parser)

    return arg_parser
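
For orientation, a hypothetical usage sketch (not part of this commit), assuming the file above is saved as args.py:
```
# Hypothetical sketch: exercise the parsers defined above.
from args import train_argparser  # assumed module name

parser = train_argparser()
run_args = parser.parse_args([
    '--train_path', 'data/datasets/conll04/conll04_train.json',
    '--types_path', 'data/datasets/conll04/conll04_types.json',
    '--final_eval',
])
print(run_args.final_eval)     # True (a store_true flag)
print(run_args.max_span_size)  # 10, a default contributed by _add_common_args
```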
config_reader.py
@@ -0,0 +1,90 @@
import copy
import multiprocessing as mp


def process_configs(target, arg_parser):
    # run 'target' once per configured run (and repeat), each in a fresh process
    args, _ = arg_parser.parse_known_args()
    ctx = mp.get_context('spawn')

    for run_args, _run_config, _run_repeat in _yield_configs(arg_parser, args):
        p = ctx.Process(target=target, args=(run_args,))
        p.start()
        p.join()


def _read_config(path):
    # parse a config file into a list of [repeat_count, {key: value}] runs;
    # runs are separated by blank lines, '[N]' sets a run's repeat count
    lines = open(path).readlines()

    runs = []
    run = [1, dict()]
    for line in lines:
        stripped_line = line.strip()

        # skip comments
        if stripped_line.startswith('#'):
            continue

        # a blank line closes the current run
        if not stripped_line:
            if run[1]:
                runs.append(run)

            run = [1, dict()]
            continue

        if stripped_line.startswith('[') and stripped_line.endswith(']'):
            repeat = int(stripped_line[1:-1])
            run[0] = repeat
        else:
            key, value = stripped_line.split('=')
            key, value = (key.strip(), value.strip())
            run[1][key] = value

    if run[1]:
        runs.append(run)

    return runs


def _convert_config(config):
    # convert a {key: value} run into an argparse-style flag list;
    # 'true' becomes a bare flag, 'false' entries are dropped here
    config_list = []
    for k, v in config.items():
        if v.lower() == 'true':
            config_list.append('--' + k)
        elif v.lower() != 'false':
            config_list.extend(['--' + k] + v.split(' '))

    return config_list


def _yield_configs(arg_parser, args, verbose=True):
    _print = (lambda x: print(x)) if verbose else lambda x: x

    if args.config:
        config = _read_config(args.config)

        for run_repeat, run_config in config:
            print("-" * 50)
            print("Config:")
            print(run_config)

            args_copy = copy.deepcopy(args)
            config_list = _convert_config(run_config)
            run_args = arg_parser.parse_args(config_list, namespace=args_copy)
            run_args_dict = vars(run_args)

            # set boolean values
            for k, v in run_config.items():
                if v.lower() == 'false':
                    run_args_dict[k] = False

            print("Repeat %s times" % run_repeat)
            print("-" * 50)

            for iteration in range(run_repeat):
                _print("Iteration %s" % iteration)
                _print("-" * 50)

                yield run_args, run_config, run_repeat

    else:
        yield args, None, None
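
The config format parsed here is a simple one: blank-line-separated runs of `key = value` pairs, with an optional `[N]` header giving a repeat count. A hypothetical round-trip sketch (file names assumed, not part of this commit):
```
# Hypothetical sketch, assuming the code above is saved as config_reader.py.
from config_reader import _read_config

with open('demo.conf', 'w') as f:
    f.write("# first run, repeated twice\n"
            "[2]\n"
            "label = demo_a\n"
            "lr = 5e-5\n"
            "\n"
            "label = demo_b\n"
            "store_examples = true\n")

for repeat, params in _read_config('demo.conf'):
    print(repeat, params)
# 2 {'label': 'demo_a', 'lr': '5e-5'}
# 1 {'label': 'demo_b', 'store_examples': 'true'}
```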
configs/example_eval.conf
@@ -0,0 +1,17 @@
[1]
label = conll04_eval
model_type = spert
model_path = data/models/conll04
tokenizer_path = data/models/conll04
dataset_path = data/datasets/conll04/conll04_test.json
types_path = data/datasets/conll04/conll04_types.json
eval_batch_size = 1
rel_filter_threshold = 0.4
size_embedding = 25
prop_drop = 0.1
max_span_size = 10
store_examples = true
sampling_processes = 4
sampling_limit = 100
max_pairs = 1000
log_path = data/log/
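
For reference, `_convert_config` from config_reader.py above turns entries like these into argparse flags, with `true` collapsing to a bare flag; a hypothetical illustration:
```
# Hypothetical sketch, assuming config_reader.py from this commit.
from config_reader import _convert_config

print(_convert_config({'store_examples': 'true', 'eval_batch_size': '1'}))
# ['--store_examples', '--eval_batch_size', '1']
```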
configs/example_train.conf
@@ -0,0 +1,28 @@
[1]
label = conll04_train
model_type = spert
model_path = bert-base-cased
tokenizer_path = bert-base-cased
train_path = data/datasets/conll04/conll04_train.json
valid_path = data/datasets/conll04/conll04_dev.json
types_path = data/datasets/conll04/conll04_types.json
train_batch_size = 2
eval_batch_size = 1
neg_entity_count = 100
neg_relation_count = 100
epochs = 20
lr = 5e-5
lr_warmup = 0.1
weight_decay = 0.01
max_grad_norm = 1.0
rel_filter_threshold = 0.4
size_embedding = 25
prop_drop = 0.1
max_span_size = 10
store_examples = true
sampling_processes = 4
sampling_limit = 100
max_pairs = 1000
final_eval = true
log_path = data/log/
save_path = data/save/
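
How these configs are consumed: spert.py itself is not shown in this excerpt, but presumably it wires an argument parser from args.py into `process_configs`; a minimal hypothetical sketch:
```
# Hypothetical wiring sketch; the real entry point (spert.py) is not in
# this excerpt. Run as e.g.: python demo.py --config configs/example_train.conf
from args import train_argparser
from config_reader import process_configs


def _train(run_args):
    # stand-in for the real training routine
    print('would train with lr =', run_args.lr)


if __name__ == '__main__':
    process_configs(target=_train, arg_parser=train_argparser())
```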