Commit bcc90d3

Initial release of code for CSL paper.

PiperOrigin-RevId: 435163343

Language Team authored and kentonl committed Mar 16, 2022
1 parent 3253325
Showing 112 changed files with 16,976 additions and 46 deletions.
107 changes: 107 additions & 0 deletions language/compgen/csl/README.md
@@ -0,0 +1,107 @@
# Compositional Structure Learner (CSL)

This directory contains code for the paper "Improving Compositional Generalization with Latent Structure and Data Augmentation"
(Linlu Qiu, Peter Shaw, Panupong Pasupat, Paweł Krzysztof Nowak, Tal Linzen, Fei Sha, Kristina Toutanova)

TODO(petershaw): Add paper link.

## Synthetic Data Generation Pipeline

Below is a diagram of the components involved in generating synthetic data.


![csl_flowchart](csl_flowchart.jpg)

The yellow boxes show the paths of Python files relative to this directory.
The purple boxes are generated artifacts.

The input files consist of:

* Training and evaluation examples, in `.tsv` format with one line per
example, encoded as `<input>\t<output>\n` (see the sketch after this list).
* Three `.json` configuration files, one for each stage of the pipeline.
* An optional `.txt` file with seed rules for grammar induction.
* An optional `.txt` file with a Context-Free Grammar (CFG) defining valid outputs for the given task.
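As an illustration, here is a minimal sketch of writing examples in this format
using the repository's `tsv_utils` module; the example pairs are hypothetical
GeoQuery-style strings, not taken from the actual data:

```python
from language.compgen.nqg.tasks import tsv_utils

# Hypothetical (input, output) pairs; real data comes from each task's
# preprocessing scripts.
examples = [
    ("what is the capital of ohio", "answer ( capital ( ohio ) )"),
    ("how many states border texas",
     "answer ( count ( state ( next_to ( texas ) ) ) )"),
]

# Writes one "<input>\t<output>" line per example.
tsv_utils.write_tsv(examples, "/tmp/train.tsv")
```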

Examples of these input files are referenced below for each task studied in the
paper.

Note that for grammar induction to be effective, it can be necessary to encode
outputs as strings in a way that enables decomposition within the QCFG
formalism, i.e. so that corresponding sub-spans between inputs and outputs can
be identified. For example, the SCAN input "jump twice" decomposes naturally
when its output is encoded as "JUMP JUMP", since the input span "jump"
corresponds to the output span "JUMP". It can also be helpful to determine a
set of seed rules and a CFG that constrains valid outputs. For most tasks, the
seed rules and the output CFG are generated from the training data using
various heuristics, as a preprocessing step.

Once synthetic training examples have been generated, you can combine the
original and synthetic examples using the script `augment/merge_tsvs.py` (a
sketch of the equivalent operation is shown below). The combined examples can
be used to train any downstream model. To reproduce the results in the paper,
you can follow the instructions for fine-tuning and inference with T5 models
here:

https://github.com/google-research/language/blob/master/language/compgen/nqg/README.md#t5
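
For reference, a minimal sketch of what the merge step accomplishes, assuming
hypothetical file paths (the actual script's flags may differ):

```python
from language.compgen.nqg.tasks import tsv_utils

# Concatenate the original and synthetic examples into one training file.
original = tsv_utils.read_tsv("/tmp/train.tsv")
synthetic = tsv_utils.read_tsv("/tmp/synthetic.tsv")
tsv_utils.write_tsv(original + synthetic, "/tmp/train_augmented.tsv")
```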

## Setup and Prerequisites

All Python scripts should be run with Python 3 from the top level of this
repository using `-m`. For example:

```shell
python -m language.compgen.csl.induction.search_main
```

Widely used prerequisite modules are `absl-py` and `tensorflow`.

We provide two versions of several modules: one that uses Apache Beam for
efficient parallel processing of larger datasets, and one that does not have
this dependency.
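
For example, the Beam variant of the data augmentation script (included later
in this commit) can be run locally with the direct runner. The flags below are
defined in `augment/generate_synthetic_examples_beam.py`; the paths are
placeholders:

```shell
python -m language.compgen.csl.augment.generate_synthetic_examples_beam \
  --augment_config=/path/to/augment_config.json \
  --rules=/path/to/rules.txt \
  --output=/tmp/synthetic.tsv \
  --num_examples=1000 \
  --pipeline_options=--runner=DirectRunner
```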

## Tasks

### SCAN

You can find the instructions for downloading and preprocessing the SCAN splits here:

https://github.com/google-research/language/blob/master/language/compgen/nqg/README.md#scan

The necessary configuration files for SCAN are located under `tasks/scan`. SCAN
does not use any seed rules or an output CFG.

### GeoQuery

You can find instructions for GeoQuery preprocessing here:

https://github.com/google-research/language/blob/master/language/compgen/nqg/README.md#geoquery

Additionally, the example IDs for the new TMCD and Template splits are located
in the `tasks/geoquery/splits` directory.

The configuration files and an output CFG for GeoQuery are located under
`tasks/geoquery`. The seed rules are generated using the script
`tasks/generate_exact_match_rules.py`.

### COGS

The relevant files for COGS are located in the `tasks/cogs` directory. This
includes the script `tasks/tools/preprocess_cogs_data.py`, which converts COGS
examples to the variable-free intermediate representation we use. Also included
are the seed rules, which were generated using an IBM alignment model run on
the training data, and the corresponding output CFG.

The dataset can be downloaded from: https://github.com/najoungkim/COGS

### SMCalFlow-CS

The configuration files for SMCalFlow-CS are included under `tasks/smcalflow`.
Additionally, there is a tool, `smcalflow/tools/filter_examples.py`, for
heuristically filtering the training data to discard some noisy examples.
There are also tools to generate seed rules and an output CFG from the
training data: `smcalflow/tools/generate_identity_rules.py` and
`smcalflow/tools/generate_target_cfg.py`. The generated seed rules are used
along with the manually specified seed rules in
`smcalflow/manual_seed_rules.txt`.

There are also additional utilities for preprocessing the data for T5.
76 changes: 76 additions & 0 deletions language/compgen/csl/augment/generate_synthetic_examples.py
@@ -0,0 +1,76 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""QCFG-based data augmentation."""

from absl import app
from absl import flags
from language.compgen.csl.augment import sampler_utils
from language.compgen.nqg.tasks import tsv_utils

FLAGS = flags.FLAGS

flags.DEFINE_string("augment_config", "", "Augment config file.")

flags.DEFINE_string("output", "", "Output TSV file.")

flags.DEFINE_integer("num_examples", 1000,
"The number of examples to generate.")

flags.DEFINE_string("rules", "", "The QCFG rules.")

flags.DEFINE_string("target_grammar", "", "Optional target CFG.")

flags.DEFINE_string("model_dir", "", "Optional model directory.")

flags.DEFINE_string("checkpoint", "", "Checkpoint prefix, or None for latest.")

flags.DEFINE_string("model_config", "", "Model config file.")

flags.DEFINE_bool("verbose", False, "Whether to print debug output.")

flags.DEFINE_bool("allow_duplicates", True,
"Whether to allow duplicate examples.")

flags.DEFINE_bool("save_sampler", False, "Whether to save sampler.")


def main(unused_argv):
  sampler = sampler_utils.get_sampler_wrapper(
      augment_config=FLAGS.augment_config,
      model_dir=FLAGS.model_dir,
      model_config=FLAGS.model_config,
      rules=FLAGS.rules,
      target_grammar_file=FLAGS.target_grammar,
      checkpoint=FLAGS.checkpoint,
      verbose=FLAGS.verbose)

  examples = []
  if FLAGS.allow_duplicates:
    while len(examples) < FLAGS.num_examples:
      source, target = sampler.sample_example(len(examples))
      examples.append((source, target))
  else:
    # Sample into a set until it contains num_examples unique pairs.
    examples_set = set()
    while len(examples_set) < FLAGS.num_examples:
      source, target = sampler.sample_example(len(examples_set))
      examples_set.add((source, target))
    examples = list(examples_set)
  tsv_utils.write_tsv(examples, FLAGS.output)
  if FLAGS.save_sampler:
    sampler.save()


if __name__ == "__main__":
app.run(main)
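
For reference, this script can be invoked as follows; all flags are defined
above, and the paths are placeholders:

```shell
python -m language.compgen.csl.augment.generate_synthetic_examples \
  --augment_config=/path/to/augment_config.json \
  --rules=/path/to/rules.txt \
  --output=/tmp/synthetic.tsv \
  --num_examples=1000
```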
95 changes: 95 additions & 0 deletions language/compgen/csl/augment/generate_synthetic_examples_beam.py
@@ -0,0 +1,95 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""QCFG-based data augmentation using Beam."""

from absl import app
from absl import flags
from absl import logging
import apache_beam as beam
from language.compgen.csl.augment import sampler_utils
from language.compgen.nqg.tasks import tsv_utils


FLAGS = flags.FLAGS

flags.DEFINE_string("augment_config", "", "Augment config file.")

flags.DEFINE_string("output", "", "Output TSV file.")

flags.DEFINE_integer("num_examples", 1000,
"The number of examples to generate.")

flags.DEFINE_string("rules", "", "The QCFG rules.")

flags.DEFINE_string("target_grammar", "", "Optional target CFG.")

flags.DEFINE_string("model_dir", "", "Optional model directory.")

flags.DEFINE_string("checkpoint", "", "Checkpoint prefix, or None for latest.")

flags.DEFINE_string("model_config", "", "Model config file.")

flags.DEFINE_bool("verbose", False, "Whether to print debug output.")

flags.DEFINE_bool(
    "allow_duplicates", True,
    "Whether to allow duplicate examples. If not allow_duplicates, "
    "the number of generated examples might be smaller than num_examples.")

flags.DEFINE_list(
    "pipeline_options", ["--runner=DirectRunner"],
    "A comma-separated list of command line arguments to be used as options "
    "for the Beam Pipeline.")


def sample_example(i, sampler):
  # Increment a Beam counter so progress is visible in pipeline metrics.
  beam.metrics.Metrics.counter("SampleExamples", "num_examples").inc()
  return sampler.sample_example(i)


def main(unused_argv):
  sampler = sampler_utils.get_sampler_wrapper(
      augment_config=FLAGS.augment_config,
      model_dir=FLAGS.model_dir,
      model_config=FLAGS.model_config,
      rules=FLAGS.rules,
      target_grammar_file=FLAGS.target_grammar,
      checkpoint=FLAGS.checkpoint,
      verbose=FLAGS.verbose)

  def _sample_examples(pipeline):
    # Each integer seed yields one sampled (source, target) example.
    seeds = range(FLAGS.num_examples)
    examples = (
        pipeline
        | "Create" >> beam.Create(seeds)
        | "SampleExamples" >> beam.Map(sample_example, sampler=sampler)
        | "Format" >> beam.Map(lambda ex: "%s\t%s" % (ex[0], ex[1])))
    if not FLAGS.allow_duplicates:
      examples = examples | "RemoveDuplicates" >> beam.Distinct()
    _ = examples | "WriteExamples" >> beam.io.WriteToText(FLAGS.output)

  pipeline_options = beam.options.pipeline_options.PipelineOptions(
      FLAGS.pipeline_options)
  with beam.Pipeline(pipeline_options) as pipeline:
    _sample_examples(pipeline)

  # The pipeline has completed; log counters and merge the output shards.
  metrics = pipeline.result.metrics().query()
  for counter in metrics["counters"]:
    logging.info("%s: %s", counter.key.metric.name, counter.committed)
  tsv_utils.merge_shared_tsvs(FLAGS.output)


if __name__ == "__main__":
app.run(main)
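
The same pipeline can also run in a distributed setting by changing
`--pipeline_options`. For example, using the standard Apache Beam options for
Google Cloud Dataflow would look roughly like the following (the project,
region, and bucket names are placeholders):

```shell
python -m language.compgen.csl.augment.generate_synthetic_examples_beam \
  --augment_config=/path/to/augment_config.json \
  --rules=/path/to/rules.txt \
  --output=gs://my-bucket/synthetic.tsv \
  --num_examples=1000000 \
  --pipeline_options=--runner=DataflowRunner,--project=my-project,--region=us-central1,--temp_location=gs://my-bucket/tmp
```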