Commit 9968fae

initial code commit and updated README

markus-eberts committed Nov 27, 2019
1 parent 35c6a80 commit 9968fae
Showing 23 changed files with 3,246 additions and 2 deletions.
113 changes: 113 additions & 0 deletions .gitignore
@@ -0,0 +1,113 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
migrations

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# custom
data/
tmp/
.idea/
runs/
*.state
checkpoint
48 changes: 46 additions & 2 deletions README.md
@@ -1,4 +1,48 @@
# SpERT: Span-based Entity and Relation Transformer
PyTorch code for "SpERT: Span-based Entity and Relation Transformer"

## Setup
### Requirements
- Required
  - Python 3.5+
  - PyTorch 1.1.0+ (tested with version 1.3.1)
  - transformers 2.2.0+ (tested with version 2.2.0)
  - scikit-learn (tested with version 0.21.3)
  - tqdm (tested with version 4.19.5)
  - numpy (tested with version 1.17.4)
- Optional
  - jinja2 (tested with version 2.10) - if installed, used to export relation extraction examples
  - tensorboardX (tested with version 1.6) - if installed, used to log the training process to TensorBoard
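The tested versions above can be installed in one go; a minimal sketch (the pins simply mirror the tested versions listed here):
```
pip install torch==1.3.1 transformers==2.2.0 scikit-learn==0.21.3 tqdm==4.19.5 numpy==1.17.4
pip install jinja2==2.10 tensorboardX==1.6  # optional
```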

### Fetch data
Fetch the datasets, converted to a common JSON format (CoNLL04 \[1\], SciERC \[2\] and ADE \[3\]):
```
bash ./scripts/fetch_datasets.sh
```
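All three datasets share one JSON format. As a rough orientation, a document looks like the sketch below (field names and label strings are assumptions, not taken from this commit; inspect the fetched files for the exact schema):
```
[{"tokens": ["John", "lives", "in", "Boston"],
  "entities": [{"type": "Peop", "start": 0, "end": 1},
               {"type": "Loc", "start": 3, "end": 4}],
  "relations": [{"type": "Live_In", "head": 0, "tail": 1}]}]
```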

Fetch model checkpoints (best out of 5 runs for each dataset):
```
bash ./scripts/fetch_models.sh
```

## Examples
Evaluate a model trained on CoNLL04 on the CoNLL04 test dataset:
```
python ./spert.py eval --config configs/example_eval.conf
```

Train SpERT on the CoNLL04 train dataset and evaluate on the dev dataset:
```
python ./spert.py train --config configs/example_train.conf
```

## Notes
- To train SpERT with SciBERT, download SciBERT from https://github.com/allenai/scibert (under "PyTorch HuggingFace Models") and set `model_path` and `tokenizer_path` to point to the SciBERT directory, as in the sketch below this list.
- Run `python ./spert.py train --help` or `python ./spert.py eval --help` for a description of the training/evaluation arguments.
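For instance, after downloading and extracting SciBERT, the relevant lines of `configs/example_train.conf` might read as follows (the directory path is hypothetical):
```
model_path = data/models/scibert_scivocab_cased
tokenizer_path = data/models/scibert_scivocab_cased
```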

## References
```
[1] Dan Roth and Wen-tau Yih, ‘A Linear Programming Formulation for Global Inference in Natural Language Tasks’, in Proc. of CoNLL 2004 at HLT-NAACL 2004, pp. 1–8, Boston, Massachusetts, USA, (May 6 – May 7 2004). ACL.
[2] Yi Luan, Luheng He, Mari Ostendorf, and Hannaneh Hajishirzi, ‘Multi-Task Identification of Entities, Relations, and Coreference for Scientific Knowledge Graph Construction’, in Proc. of EMNLP 2018, pp. 3219–3232, Brussels, Belgium, (October–November 2018). ACL.
[3] Harsha Gurulingappa, Abdul Mateen Rajput, Angus Roberts, Juliane Fluck, Martin Hofmann-Apitius, and Luca Toldo, ‘Development of a Benchmark Corpus to Support the Automatic Extraction of Drug-related Adverse Effects from Medical Case Reports’, J. of Biomedical Informatics, 45(5), 885–892, (October 2012).
```
Empty file added __init__.py
90 changes: 90 additions & 0 deletions args.py
@@ -0,0 +1,90 @@
import argparse


def _add_common_args(arg_parser):
    arg_parser.add_argument('--config', type=str)

    # Input
    arg_parser.add_argument('--types_path', type=str, help="Path to type specifications")

    # Preprocessing
    arg_parser.add_argument('--tokenizer_path', type=str, help="Path to tokenizer")
    arg_parser.add_argument('--max_span_size', type=int, default=10, help="Maximum size of spans")
    arg_parser.add_argument('--lowercase', action='store_true', default=False,
                            help="If true, input is lowercased during preprocessing")
    arg_parser.add_argument('--sampling_processes', type=int, default=4,
                            help="Number of sampling processes. 0 = no multiprocessing for sampling")
    arg_parser.add_argument('--sampling_limit', type=int, default=100, help="Maximum number of sample batches in queue")

    # Logging
    arg_parser.add_argument('--label', type=str, help="Label of run. Used as the directory name of logs/models")
    arg_parser.add_argument('--log_path', type=str, help="Path to directory where training/evaluation logs are stored")
    arg_parser.add_argument('--store_examples', action='store_true',
                            help="If true, store evaluation examples on disk (in log directory)")
    arg_parser.add_argument('--example_count', type=int, default=None,
                            help="Count of evaluation examples to store (if store_examples == True)")
    arg_parser.add_argument('--debug', action='store_true', default=False, help="Debugging mode on/off")

    # Model / Training / Evaluation
    arg_parser.add_argument('--model_path', type=str, help="Path to directory that contains model checkpoints")
    arg_parser.add_argument('--model_type', type=str, default="spert", help="Type of model")
    arg_parser.add_argument('--cpu', action='store_true', default=False,
                            help="If true, train/evaluate on CPU even if a CUDA device is available")
    arg_parser.add_argument('--eval_batch_size', type=int, default=1, help="Evaluation batch size")
    arg_parser.add_argument('--max_pairs', type=int, default=1000,
                            help="Maximum entity pairs to process during training/evaluation")
    arg_parser.add_argument('--rel_filter_threshold', type=float, default=0.4, help="Filter threshold for relations")
    arg_parser.add_argument('--size_embedding', type=int, default=25, help="Dimensionality of size embedding")
    arg_parser.add_argument('--prop_drop', type=float, default=0.1, help="Probability of dropout used in SpERT")
    arg_parser.add_argument('--freeze_transformer', action='store_true', default=False, help="Freeze BERT weights")

    # Misc
    arg_parser.add_argument('--seed', type=int, default=None, help="Seed")
    arg_parser.add_argument('--cache_path', type=str, default=None,
                            help="Path to cache transformer models (for HuggingFace transformers library)")


def train_argparser():
    arg_parser = argparse.ArgumentParser()

    # Input
    arg_parser.add_argument('--train_path', type=str, help="Path to train dataset")
    arg_parser.add_argument('--valid_path', type=str, help="Path to validation dataset")

    # Logging
    arg_parser.add_argument('--save_path', type=str, help="Path to directory where model checkpoints are stored")
    arg_parser.add_argument('--init_eval', action='store_true', default=False,
                            help="If true, evaluate validation set before training")
    arg_parser.add_argument('--save_optimizer', action='store_true', default=False,
                            help="Save optimizer alongside model")
    arg_parser.add_argument('--train_log_iter', type=int, default=1, help="Log training process every x iterations")
    arg_parser.add_argument('--final_eval', action='store_true', default=False,
                            help="Evaluate the model only after training, not at every epoch")

    # Model / Training
    arg_parser.add_argument('--train_batch_size', type=int, default=2, help="Training batch size")
    arg_parser.add_argument('--epochs', type=int, default=20, help="Number of epochs")
    arg_parser.add_argument('--neg_entity_count', type=int, default=100,
                            help="Number of negative entity samples per document (sentence)")
    arg_parser.add_argument('--neg_relation_count', type=int, default=100,
                            help="Number of negative relation samples per document (sentence)")
    arg_parser.add_argument('--lr', type=float, default=5e-5, help="Learning rate")
    arg_parser.add_argument('--lr_warmup', type=float, default=0.1,
                            help="Proportion of total train iterations to warmup in linear increase/decrease schedule")
    arg_parser.add_argument('--weight_decay', type=float, default=0.01, help="Weight decay to apply")
    arg_parser.add_argument('--max_grad_norm', type=float, default=1.0, help="Maximum gradient norm")

    _add_common_args(arg_parser)

    return arg_parser


def eval_argparser():
    arg_parser = argparse.ArgumentParser()

    # Input
    arg_parser.add_argument('--dataset_path', type=str, help="Path to dataset")

    _add_common_args(arg_parser)

    return arg_parser
90 changes: 90 additions & 0 deletions config_reader.py
@@ -0,0 +1,90 @@
import copy
import multiprocessing as mp


def process_configs(target, arg_parser):
    args, _ = arg_parser.parse_known_args()
    ctx = mp.get_context('spawn')

    # run each configuration in a fresh process so state cannot leak between runs
    for run_args, _run_config, _run_repeat in _yield_configs(arg_parser, args):
        p = ctx.Process(target=target, args=(run_args,))
        p.start()
        p.join()


def _read_config(path):
    lines = open(path).readlines()

    runs = []
    run = [1, dict()]  # [repeat count, settings]
    for line in lines:
        stripped_line = line.strip()

        # continue in case of comment
        if stripped_line.startswith('#'):
            continue

        # an empty line ends the current run definition
        if not stripped_line:
            if run[1]:
                runs.append(run)

            run = [1, dict()]
            continue

        # '[n]' sets how often the run is repeated
        if stripped_line.startswith('[') and stripped_line.endswith(']'):
            repeat = int(stripped_line[1:-1])
            run[0] = repeat
        else:
            key, value = stripped_line.split('=')
            key, value = (key.strip(), value.strip())
            run[1][key] = value

    if run[1]:
        runs.append(run)

    return runs
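# Example (illustrative, not from the original source): a config file containing
#
#   [2]
#   lr = 5e-5
#   store_examples = true
#
# yields [[2, {'lr': '5e-5', 'store_examples': 'true'}]] -- the run is repeated
# twice, and all values stay strings until argparse converts them later.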


def _convert_config(config):
    # convert key/value pairs into an argparse-style argument list:
    # 'true' becomes a bare flag and 'false' entries are omitted
    config_list = []
    for k, v in config.items():
        if v.lower() == 'true':
            config_list.append('--' + k)
        elif v.lower() != 'false':
            config_list.extend(['--' + k] + v.split(' '))

    return config_list
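# Example (illustrative): {'lr': '5e-5', 'store_examples': 'true', 'cpu': 'false'}
# becomes ['--lr', '5e-5', '--store_examples'] -- 'true' turns into a bare flag,
# while 'false' entries are dropped here and reset in _yield_configs below.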


def _yield_configs(arg_parser, args, verbose=True):
    _print = (lambda x: print(x)) if verbose else lambda x: x

    if args.config:
        config = _read_config(args.config)

        for run_repeat, run_config in config:
            print("-" * 50)
            print("Config:")
            print(run_config)

            args_copy = copy.deepcopy(args)
            config_list = _convert_config(run_config)
            run_args = arg_parser.parse_args(config_list, namespace=args_copy)
            run_args_dict = vars(run_args)

            # set boolean values
            for k, v in run_config.items():
                if v.lower() == 'false':
                    run_args_dict[k] = False

            print("Repeat %s times" % run_repeat)
            print("-" * 50)

            for iteration in range(run_repeat):
                _print("Iteration %s" % iteration)
                _print("-" * 50)

                yield run_args, run_config, run_repeat

    else:
        yield args, None, None
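Taken together, `args.py` and `config_reader.py` suggest how the `spert.py` entry point (not included in this excerpt) wires parsing and process spawning together; a hypothetical sketch, with `_train` standing in for the real training routine:
```
from args import train_argparser
from config_reader import process_configs


def _train(run_args):
    ...  # build the trainer from run_args and start training


if __name__ == '__main__':
    # one spawned process per configured run, repeated run_repeat times
    process_configs(target=_train, arg_parser=train_argparser())
```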
17 changes: 17 additions & 0 deletions configs/example_eval.conf
@@ -0,0 +1,17 @@
[1]
label = conll04_eval
model_type = spert
model_path = data/models/conll04
tokenizer_path = data/models/conll04
dataset_path = data/datasets/conll04/conll04_test.json
types_path = data/datasets/conll04/conll04_types.json
eval_batch_size = 1
rel_filter_threshold = 0.4
size_embedding = 25
prop_drop = 0.1
max_span_size = 10
store_examples = true
sampling_processes = 4
sampling_limit = 100
max_pairs = 1000
log_path = data/log/
28 changes: 28 additions & 0 deletions configs/example_train.conf
@@ -0,0 +1,28 @@
[1]
label = conll04_train
model_type = spert
model_path = bert-base-cased
tokenizer_path = bert-base-cased
train_path = data/datasets/conll04/conll04_train.json
valid_path = data/datasets/conll04/conll04_dev.json
types_path = data/datasets/conll04/conll04_types.json
train_batch_size = 2
eval_batch_size = 1
neg_entity_count = 100
neg_relation_count = 100
epochs = 20
lr = 5e-5
lr_warmup = 0.1
weight_decay = 0.01
max_grad_norm = 1.0
rel_filter_threshold = 0.4
size_embedding = 25
prop_drop = 0.1
max_span_size = 10
store_examples = true
sampling_processes = 4
sampling_limit = 100
max_pairs = 1000
final_eval = true
log_path = data/log/
save_path = data/save/
