Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
clarkkev committed Mar 11, 2020
2 parents a71d649 + 423dc8b commit 96d59a0
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 6 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ Use `build_pretraining_dataset.py` to create a pre-training dataset from a dump
* `--vocab-file`: File defining the wordpiece vocabulary.
* `--output-dir`: Where to write out ELECTRA examples.
* `--max-seq-length`: The number of tokens per example (128 by default).
* `--num-processes`: If >1 parallelize across multiple processes (1 by default)
* `--num-processes`: If >1 parallelize across multiple processes (1 by default).
* `--blanks-separate-docs`: Whether blank lines indicate document boundaries (True by default).
* `--do-lower-case/--no-lower-case`: Whether to lowercase the input text (True by default; pass `--no-lower-case` to disable).

Use `run_pretraining.py` to pre-train an ELECTRA model. It has the following arguments:

Expand Down Expand Up @@ -158,7 +160,7 @@ If you use this code for your publication, please cite the original paper:
```
@inproceedings{clark2020electra,
title = {{ELECTRA}: Pre-training Text Encoders as Discriminators Rather Than Generators},
author = {Kevin Clark and Minh-Thang Luong and and Quoc V. Le and Christopher D. Manning},
author = {Kevin Clark and Minh-Thang Luong and Quoc V. Le and Christopher D. Manning},
booktitle = {ICLR},
year = {2020}
}
Expand Down
8 changes: 7 additions & 1 deletion build_openwebtext_pretraining_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def log(*args):
output_dir=os.path.join(args.data_dir, "pretrain_tfrecords"),
max_seq_length=args.max_seq_length,
num_jobs=args.num_processes,
blanks_separate_docs=False
blanks_separate_docs=False,
do_lower_case=args.do_lower_case
)
log("Writing tf examples")
fnames = sorted(tf.io.gfile.listdir(owt_dir))
Expand Down Expand Up @@ -78,6 +79,11 @@ def main():
help="Number of tokens per example.")
parser.add_argument("--num-processes", default=1, type=int,
help="Parallelize across multiple processes.")
parser.add_argument("--do-lower-case", dest='do_lower_case',
action='store_true', help="Lower case input text.")
parser.add_argument("--no-lower-case", dest='do_lower_case',
action='store_false', help="Don't lower case input text.")
parser.set_defaults(do_lower_case=True)
args = parser.parse_args()

utils.rmkdir(os.path.join(args.data_dir, "pretrain_tfrecords"))
Expand Down
13 changes: 10 additions & 3 deletions build_pretraining_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,12 @@ class ExampleWriter(object):
"""Writes pre-training examples to disk."""

def __init__(self, job_id, vocab_file, output_dir, max_seq_length,
num_jobs, blanks_separate_docs, num_out_files=1000):
num_jobs, blanks_separate_docs, do_lower_case,
num_out_files=1000):
self._blanks_separate_docs = blanks_separate_docs
tokenizer = tokenization.FullTokenizer(
vocab_file=vocab_file,
do_lower_case=True)
do_lower_case=do_lower_case)
self._example_builder = ExampleBuilder(tokenizer, max_seq_length)
self._writers = []
for i in range(num_out_files):
Expand Down Expand Up @@ -169,7 +170,8 @@ def log(*args):
output_dir=args.output_dir,
max_seq_length=args.max_seq_length,
num_jobs=args.num_processes,
blanks_separate_docs=args.blanks_separate_docs
blanks_separate_docs=args.blanks_separate_docs,
do_lower_case=args.do_lower_case
)
log("Writing tf examples")
fnames = sorted(tf.io.gfile.listdir(args.corpus_dir))
Expand Down Expand Up @@ -204,6 +206,11 @@ def main():
help="Parallelize across multiple processes.")
parser.add_argument("--blanks-separate-docs", default=True, type=bool,
help="Whether blank lines indicate document boundaries.")
parser.add_argument("--do-lower-case", dest='do_lower_case',
action='store_true', help="Lower case input text.")
parser.add_argument("--no-lower-case", dest='do_lower_case',
action='store_false', help="Don't lower case input text.")
parser.set_defaults(do_lower_case=True)
args = parser.parse_args()

utils.rmkdir(args.output_dir)
Expand Down

0 comments on commit 96d59a0

Please sign in to comment.