Initial release of code for CSL paper.
PiperOrigin-RevId: 435163343
Showing 112 changed files with 16,976 additions and 46 deletions.
# Compositional Structure Learner (CSL)

This directory contains code for the paper "Improving Compositional
Generalization with Latent Structure and Data Augmentation"
(Linlu Qiu, Peter Shaw, Panupong Pasupat, Paweł Krzysztof Nowak, Tal Linzen,
Fei Sha, Kristina Toutanova).

TODO(petershaw): Add paper link.
||
## Synthetic Data Generation Pipeline

Below is a diagram of the components involved in generating synthetic data.

![csl_flowchart](csl_flowchart.jpg)

The yellow boxes show the paths of Python files relative to this directory.
The purple boxes are generated artifacts.
The input files consist of:

* Training and evaluation examples, in `.tsv` format, with one example per
  line encoded as `<input>\t<output>\n`.
* Three `.json` configuration files, one for each stage of the pipeline.
* An optional `.txt` file with seed rules for grammar induction.
* An optional `.txt` file with a Context-Free Grammar (CFG) defining valid
  outputs for the given task.
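As an illustration of the example format, here is a minimal sketch of reading
and writing such `.tsv` files. The function names, file names, and example
strings are hypothetical, not part of this repository:

```python
# Minimal sketch of the `<input>\t<output>\n` TSV format described above.
# These helpers are illustrative only; the repository uses its own
# `tsv_utils` module for this purpose.


def read_examples(path):
  """Reads (input, output) pairs from a TSV file, one example per line."""
  examples = []
  with open(path) as f:
    for line in f:
      source, target = line.rstrip("\n").split("\t")
      examples.append((source, target))
  return examples


def write_examples(examples, path):
  """Writes (input, output) pairs as `<input>\t<output>\n` lines."""
  with open(path, "w") as f:
    for source, target in examples:
      f.write("%s\t%s\n" % (source, target))
```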
Examples of these input files are referenced below for each task studied in the
paper.

Note that for grammar induction to be effective, outputs may need to be encoded
as strings in a way that enables a decomposition within the QCFG formalism,
i.e. one that makes it possible to identify corresponding sub-spans between
inputs and outputs. It can also be helpful to provide a set of seed rules and a
CFG that constrains valid outputs. For most tasks, the seed rules and the
output CFG are generated from the training data using various heuristics, as a
preprocessing step.
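To make the idea of sub-span correspondence concrete, here is a toy sketch of a
QCFG-style rule: paired source and target templates sharing non-terminal slots,
so that a sub-span of the input maps to a sub-span of the output. This is not
the repository's actual rule format; the templates and slot names are invented
for illustration:

```python
# Toy illustration of a QCFG-style rule. A shared non-terminal (here "NT_1")
# links a sub-span of the input to a sub-span of the output, which is what
# makes decomposition possible. NOT the repository's actual rule format.


def apply_rule(source_template, target_template, substitutions):
  """Fills the shared non-terminal slots in both templates."""
  source = source_template
  target = target_template
  for nt, (src_span, tgt_span) in substitutions.items():
    source = source.replace(nt, src_span)
    target = target.replace(nt, tgt_span)
  return source, target


# "NT_1 twice" on the source side corresponds to "NT_1 NT_1" on the target
# side; substituting the pair ("jump", "JUMP") yields a full example.
print(apply_rule("NT_1 twice", "NT_1 NT_1", {"NT_1": ("jump", "JUMP")}))
# → ('jump twice', 'JUMP JUMP')
```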
Once synthetic training examples are generated, you can combine the original
and synthetic examples using the script `augment/merge_tsvs.py`. These examples
can be used to train any downstream model. To reproduce the results in the
paper, you can follow the instructions below for fine-tuning and inference with
T5 models:

https://github.com/google-research/language/blob/master/language/compgen/nqg/README.md#t5
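Conceptually, this merge step amounts to concatenating TSV files. The sketch
below is only illustrative; the actual script is `augment/merge_tsvs.py`, and
its flags and behavior may differ:

```python
# Illustrative sketch of merging original and synthetic example files.
# The real merge is done by `augment/merge_tsvs.py`; this function and its
# signature are hypothetical.


def merge_tsvs(input_paths, output_path):
  """Concatenates several `<input>\t<output>` TSV files into one."""
  with open(output_path, "w") as out_f:
    for path in input_paths:
      with open(path) as in_f:
        for line in in_f:
          out_f.write(line)
```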
## Setup and Prerequisites

All Python scripts should be run with Python 3 from the top level of this
repository, using the `-m` flag. For example:

```shell
python -m language.compgen.csl.induction.search_main
```
The main prerequisite modules are `absl-py` and `tensorflow`.

We provide two versions of various modules: one that can use Apache Beam for
efficient parallel processing of larger datasets, and one without this
dependency.
## Tasks

### SCAN

You can find instructions for downloading and preprocessing the SCAN splits
here:

https://github.com/google-research/language/blob/master/language/compgen/nqg/README.md#scan

The necessary configuration files for SCAN are located under `tasks/scan`.
SCAN does not use any seed rules or an output CFG.
### GeoQuery

You can find instructions for GeoQuery preprocessing here:

https://github.com/google-research/language/blob/master/language/compgen/nqg/README.md#geoquery

Additionally, the example IDs for the new TMCD and Template splits are located
in the `tasks/geoquery/splits` directory.

The configuration files and an output CFG for GeoQuery are located under
`tasks/geoquery`. The seed rules are generated using the script
`tasks/generate_exact_match_rules.py`.
### COGS

The relevant files for COGS are located in the `tasks/cogs` directory. This
includes the script `tasks/tools/preprocess_cogs_data.py`, which converts COGS
examples to the variable-free intermediate representation we use. Also included
are the seed rules, which were generated using an IBM alignment model run on
the training data, and the corresponding output CFG.

The dataset can be downloaded from: https://github.com/najoungkim/COGS
### SMCalFlow-CS

The configuration files for SMCalFlow-CS are included under `tasks/smcalflow`.
Additionally, `smcalflow/tools/filter_examples.py` heuristically filters the
training data to discard some noisy examples, and
`smcalflow/tools/generate_identity_rules.py` and
`smcalflow/tools/generate_target_cfg.py` generate seed rules and an output CFG
from the training data. The generated seed rules are used along with the
manually specified seed rules in `smcalflow/manual_seed_rules.txt`.

There are also additional utilities for preprocessing the data for T5.
language/compgen/csl/augment/generate_synthetic_examples.py (76 additions, 0 deletions)
```python
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""QCFG-based data augmentation."""

from absl import app
from absl import flags
from language.compgen.csl.augment import sampler_utils
from language.compgen.nqg.tasks import tsv_utils

FLAGS = flags.FLAGS

flags.DEFINE_string("augment_config", "", "Augment config file.")

flags.DEFINE_string("output", "", "Output TSV file.")

flags.DEFINE_integer("num_examples", 1000,
                     "The number of examples to generate.")

flags.DEFINE_string("rules", "", "The QCFG rules.")

flags.DEFINE_string("target_grammar", "", "Optional target CFG.")

flags.DEFINE_string("model_dir", "", "Optional model directory.")

flags.DEFINE_string("checkpoint", "", "Checkpoint prefix, or None for latest.")

flags.DEFINE_string("model_config", "", "Model config file.")

flags.DEFINE_bool("verbose", False, "Whether to print debug output.")

flags.DEFINE_bool("allow_duplicates", True,
                  "Whether to allow duplicate examples.")

flags.DEFINE_bool("save_sampler", False, "Whether to save sampler.")


def main(unused_argv):
  sampler = sampler_utils.get_sampler_wrapper(
      augment_config=FLAGS.augment_config,
      model_dir=FLAGS.model_dir,
      model_config=FLAGS.model_config,
      rules=FLAGS.rules,
      target_grammar_file=FLAGS.target_grammar,
      checkpoint=FLAGS.checkpoint,
      verbose=FLAGS.verbose)

  examples = []
  if FLAGS.allow_duplicates:
    while len(examples) < FLAGS.num_examples:
      source, target = sampler.sample_example(len(examples))
      examples.append((source, target))
  else:
    examples_set = set()
    while len(examples_set) < FLAGS.num_examples:
      source, target = sampler.sample_example(len(examples_set))
      examples_set.add((source, target))
    examples = list(examples_set)
  tsv_utils.write_tsv(examples, FLAGS.output)
  if FLAGS.save_sampler:
    sampler.save()


if __name__ == "__main__":
  app.run(main)
```
language/compgen/csl/augment/generate_synthetic_examples_beam.py (95 additions, 0 deletions)
```python
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""QCFG-based data augmentation using Beam."""

from absl import app
from absl import flags
from absl import logging
import apache_beam as beam
from language.compgen.csl.augment import sampler_utils
from language.compgen.nqg.tasks import tsv_utils

FLAGS = flags.FLAGS

flags.DEFINE_string("augment_config", "", "Augment config file.")

flags.DEFINE_string("output", "", "Output TSV file.")

flags.DEFINE_integer("num_examples", 1000,
                     "The number of examples to generate.")

flags.DEFINE_string("rules", "", "The QCFG rules.")

flags.DEFINE_string("target_grammar", "", "Optional target CFG.")

flags.DEFINE_string("model_dir", "", "Optional model directory.")

flags.DEFINE_string("checkpoint", "", "Checkpoint prefix, or None for latest.")

flags.DEFINE_string("model_config", "", "Model config file.")

flags.DEFINE_bool("verbose", False, "Whether to print debug output.")

flags.DEFINE_bool(
    "allow_duplicates", True,
    "Whether to allow duplicate examples. If not allow_duplicates, "
    "the number of generated examples might be smaller than num_examples.")

flags.DEFINE_list(
    "pipeline_options", ["--runner=DirectRunner"],
    "A comma-separated list of command line arguments to be used as options "
    "for the Beam Pipeline.")


def sample_example(i, sampler):
  beam.metrics.Metrics.counter("SampleExamples", "num_examples").inc()
  return sampler.sample_example(i)


def main(unused_argv):
  sampler = sampler_utils.get_sampler_wrapper(
      augment_config=FLAGS.augment_config,
      model_dir=FLAGS.model_dir,
      model_config=FLAGS.model_config,
      rules=FLAGS.rules,
      target_grammar_file=FLAGS.target_grammar,
      checkpoint=FLAGS.checkpoint,
      verbose=FLAGS.verbose)

  def _sample_examples(pipeline):
    seeds = range(FLAGS.num_examples)
    examples = (
        pipeline
        | "Create" >> beam.Create(seeds)
        | "SampleExamples" >> beam.Map(sample_example, sampler=sampler)
        | "Format" >> beam.Map(lambda ex: "%s\t%s" % (ex[0], ex[1])))
    if not FLAGS.allow_duplicates:
      examples = examples | "RemoveDuplicates" >> beam.Distinct()
    _ = examples | "WriteExamples" >> beam.io.WriteToText(FLAGS.output)

  pipeline_options = beam.options.pipeline_options.PipelineOptions(
      FLAGS.pipeline_options)
  with beam.Pipeline(pipeline_options) as pipeline:
    _sample_examples(pipeline)

  metrics = pipeline.result.metrics().query()
  for counter in metrics["counters"]:
    logging.info("%s: %s", counter.key.metric.name, counter.committed)
  tsv_utils.merge_shared_tsvs(FLAGS.output)


if __name__ == "__main__":
  app.run(main)
```