- This repository contains the code to train the `tagger` and `generator` modules.
- Apart from scripts to train the modules, it also has scripts needed to run inference on the test set and to run evaluation for metrics like BLEU, ROUGE, and METEOR.
- Both `tagger` and `generator` are seq2seq models that require parallel data generated by the data prep module.
- The parallel datasets are:
  - Tagger: `entagged_parallel.{split}.en` → `entagged_parallel.{split}.tagged`
  - Generator: `engenerated_parallel.{split}.en` → `engenerated_parallel.{split}.generated`

  (where `{split}` is either train, test, or dev.)
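Concretely, the file pairs the two modules expect can be enumerated per split. This is only a sketch of the naming scheme; the actual files are produced by the data prep module:

```shell
# Enumerate the parallel file pairs for each split produced by the data prep module.
splits="train dev test"
for split in $splits; do
  printf '%s\n' "entagged_parallel.${split}.en -> entagged_parallel.${split}.tagged"
  printf '%s\n' "engenerated_parallel.${split}.en -> engenerated_parallel.${split}.generated"
done
```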
- To prepare the BPE-tokenized data (use `tagged` for the tagger and `generated` for the generator):

  ```
  bash scripts/prepare_bpe.sh [tagged|generated] {base_folder}
  ```

  Where:
  - `base_folder`: The folder in which the data files are stored (argument used in the creation of training data).
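As a concrete sketch, assuming the prepared data lives under a hypothetical `data/politeness` folder (an assumption, not a repository default), the two invocations would look like:

```shell
# Hypothetical base_folder (assumption); substitute your own data directory.
base_folder=data/politeness
# Guarded so the snippet is a no-op outside a repository checkout.
if [ -f scripts/prepare_bpe.sh ]; then
  bash scripts/prepare_bpe.sh tagged "${base_folder}"      # data for the tagger
  bash scripts/prepare_bpe.sh generated "${base_folder}"   # data for the generator
fi
```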
- To train the tagger:

  ```
  bash scripts/train_tagger.sh tagged {handle} {base_folder}
  ```

  Where:
  - `handle`: An identifier used to bucketize models trained on different datasets. Models for each `handle` are stored in separate folders with names indexed by `{handle}`, within the `{models}` directory.
  - `base_folder`: The folder in which the data files are stored (argument used in the creation of training data).
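A worked invocation, with hypothetical `handle` and `base_folder` values (assumptions, shown only to make the argument order concrete):

```shell
# Hypothetical identifiers (assumptions): adapt to your dataset.
handle=politeness
base_folder=data/politeness
# Guarded so the snippet is a no-op outside a repository checkout.
if [ -f scripts/train_tagger.sh ]; then
  bash scripts/train_tagger.sh tagged "${handle}" "${base_folder}"
fi
```

Training the generator follows the same pattern with `scripts/train_generator.sh` and the `generated` data argument.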
- To train the generator:

  ```
  bash scripts/train_generator.sh generated {handle} {base_folder}
  ```

  Where:
  - `handle`: An identifier used to bucketize models trained on different datasets. Models for each `handle` are stored in separate folders with names indexed by `{handle}`, within the `{models}` directory.
  - `base_folder`: The folder in which the data files are stored (argument used in the creation of training data).
- To run inference:

  ```
  bash scripts/inference.sh {input_file} {jobname} \
      tagged generated \
      {handle} \
      {style_0_label} {style_1_label} \
      {base_folder} {device}
  ```

  Where:
  - `input_file`: The input test file to be transferred. This is a raw text file with one sentence per line.
  - `jobname`: A unique identifier for the inference job.
  - `handle`: The dataset argument passed when training the `tagger` or `generator`; used to identify the model paths for the `tagger` and `generator`.
  - `style_0_label`: A label for style 0.
  - `style_1_label`: A label for style 1.
  - `base_folder`: The folder in which the data files are stored (argument used in the creation of training data).
  - `device`: GPU id.
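A filled-in sketch of the inference call. All values below are hypothetical assumptions (the style labels in particular depend on how the data prep module named the two styles):

```shell
# Hypothetical values (assumptions): adapt every variable to your setup.
input_file=test.en            # raw text, one sentence per line
jobname=politeness_run1       # unique identifier for this job
handle=politeness             # same handle used during training
style_0_label=style0          # assumption: label strings from data prep
style_1_label=style1
base_folder=data/politeness
device=0                      # GPU id
# Guarded so the snippet is a no-op outside a repository checkout.
if [ -f scripts/inference.sh ]; then
  bash scripts/inference.sh "${input_file}" "${jobname}" \
      tagged generated \
      "${handle}" \
      "${style_0_label}" "${style_1_label}" \
      "${base_folder}" "${device}"
fi
```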
- To run evaluation:

  ```
  bash run_context_eval.sh {hypothesis_filepath} {reference_filepath}
  ```

  Where:
  - `hypothesis_filepath`: The full path to the transferred output from the trained model (the hypothesis).
  - `reference_filepath`: The full path to the ideal output (for BLEU-r) or the original input file (for BLEU-s).
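For example, with hypothetical paths (assumptions), BLEU-r scores the hypothesis against references, while BLEU-s would pass the original input file as the second argument instead:

```shell
# Hypothetical file paths (assumptions); use your own job outputs.
hypothesis_filepath=outputs/politeness_run1/transferred.txt
reference_filepath=data/politeness/test.reference.txt
# Guarded so the snippet is a no-op outside a repository checkout.
if [ -f run_context_eval.sh ]; then
  bash run_context_eval.sh "${hypothesis_filepath}" "${reference_filepath}"
fi
```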
The trained models can be found here.
- The code for evaluation has been partially borrowed from https://github.com/Maluuba/nlg-eval
- Most of the code for the training pipeline has been borrowed from https://github.com/pmichel31415/jsalt-2019-mt-tutorial