Use TFDS for loading data.
PiperOrigin-RevId: 331730805
dfurrer authored and copybara-github committed Sep 15, 2020
1 parent 265b433 commit 4c8c5ca
Showing 5 changed files with 36 additions and 103 deletions.
11 changes: 6 additions & 5 deletions cfq/README.md
@@ -49,6 +49,7 @@ This library requires Python3 and the following Python3 libraries:
* [absl-py](https://pypi.org/project/absl-py/)
* [tensorflow](https://www.tensorflow.org/)
* [tensor2tensor](https://github.com/tensorflow/tensor2tensor)
* [tensorflow-datasets](https://www.tensorflow.org/datasets)

We recommend getting [pip3](https://pip.pypa.io/en/stable/) and then running the
following command, which will install all required libraries in one go:
@@ -57,18 +58,18 @@ following command, which will install all required libraries in one go:
sudo pip3 install -r requirements.txt
```

Note that Tensor2Tensor is no longer updated and is based on TensorFlow 1,
which is only available for Python <= 3.7.
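
If your system default is a newer Python, one way to satisfy this constraint is
a dedicated Python 3.7 virtual environment. A minimal sketch (the environment
name is arbitrary and a `python3.7` interpreter is assumed to be installed):

```shell
# Create and activate a Python 3.7 environment, then install the requirements.
python3.7 -m venv cfq_env   # "cfq_env" is just an example name
source cfq_env/bin/activate
pip install -r requirements.txt
```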

## Training and evaluating a model

First download the CFQ dataset (link above), and ensure the dataset and the
splits directory are in the same directory as this library (e.g. by unpacking
the file in the library directory). In order to train and evaluate a model,
run the following:
In order to train and evaluate a model, run the following:

```shell
bash run_experiment.sh
```

This will run preprocessing on the dataset and train an LSTM model with
This will download and preprocess the dataset, then train an LSTM model with
attention on the random split of the CFQ dataset, after which it will directly
be evaluated.
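
With TFDS there is no separate download step: the data is fetched automatically
on first use. To run on the other dataset or a different split, edit the shell
variables near the top of run_experiment.sh (the values below are the defaults
in this commit):

```shell
# In run_experiment.sh -- pick the TFDS dataset and the split to experiment on.
dataset="cfq"   # or "scan"
split="mcd1"    # e.g. random, mcd1, mcd2, mcd3
```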

89 changes: 19 additions & 70 deletions cfq/preprocess.py
@@ -17,70 +17,17 @@
"""Utils for preprocessing the CFQ dataset."""

import collections
import json
import os
import re
import string
from typing import Any, Dict, List, Tuple

from absl import logging

from tensorflow.compat.v1.io import gfile
import tensorflow_datasets as tfds

Dataset = Dict[str, List[Tuple[str, str]]]

_QUESTION_FIELD = 'questionPatternModEntities'
_QUERY_FIELD = 'sparqlPatternModEntities'


def _scrub_json(content):
  """Reduce JSON by filtering out only the fields of interest."""
  # Loading of json data with the standard Python library is very inefficient:
  # For the 4GB dataset file it requires more than 40GB of RAM and takes 3min.
  # There are more efficient libraries but in order to avoid additional
  # dependencies we use a simple (perhaps somewhat brittle) regexp to reduce
  # the content to only what is needed. This takes 1min to execute but
  # afterwards loading requires only 500MB of RAM and is done in 2s.
  regex = re.compile(
      r'("%s":\s*"[^"]*").*?("%s":\s*"[^"]*")' %
      (_QUESTION_FIELD, _QUERY_FIELD), re.DOTALL)
  return '[' + ','.join([
      '{' + m.group(1) + ',' + m.group(2) + '}' for m in regex.finditer(content)
  ]) + ']'


def load_json(path, scrub = False):
  logging.info('Reading json from %s into memory...', path)
  with gfile.GFile(path) as f:
    if scrub:
      data = json.loads(_scrub_json(f.read()))
    else:
      data = json.load(f)
  logging.info('Successfully loaded json data from %s into memory.', path)
  return data


def load_scan(path):
  """Read original scan task data and convert into CFQ-style json format."""
  logging.info('Reading SCAN tasks from %s.', path)

  def parse(infile):
    for line in infile.read().split('\n'):
      if not line.startswith('IN: '):
        continue
      commands, actions = line[len('IN: '):].strip().split(' OUT: ', 1)
      yield {_QUESTION_FIELD: commands, _QUERY_FIELD: actions}

  return list(parse(gfile.GFile(path)))


def load_dataset(path):
  """Load dataset from .json or SCAN task format."""
  if path[-5:] == '.json':
    return load_json(path, scrub=True)
  else:
    return load_scan(path)


def tokenize_punctuation(text):
  text = map(lambda c: ' %s ' % c if c in string.punctuation else c, text)
@@ -113,23 +60,25 @@ def get_encode_decode_pair(sample):
  return (encode_text, decode_text)


def get_dataset(samples, split):
  """Creates a dataset by taking @split from @samples."""
  logging.info('Retrieving splits...')
  split_names = ['train', 'dev', 'test']
  idx_names = [f'{s}Idxs' for s in split_names]
def get_dataset_from_tfds(dataset, split):
  """Load the given split of the dataset via TFDS and preprocess it."""
  logging.info('Loading dataset via TFDS.')
  allsplits = tfds.load(dataset + '/' + split, as_supervised=True)
  split_names = {'train': 'train', 'validation': 'dev', 'test': 'test'}
  if dataset == 'scan':
    # scan has 'train' and 'test' sets only. We call the test set dev in our
    # output to keep the bash script simple.
    split_names = {'train': 'train', 'test': 'dev'}

  dataset = collections.defaultdict(list)
  if not set(idx_names) <= split.keys():
    logging.fatal('Invalid split: JSON should contain fields %s.', idx_names)
    return dataset
  for split_name, idx_name in zip(split_names, idx_names):
    logging.info(
        ' Retrieving %s (%s instances)', split_name, len(split[idx_name]))
    for idx in split[idx_name]:
      dataset[split_name].append(get_encode_decode_pair(samples[idx]))

  size_str = ', '.join('%s=%s' %(s, len(dataset[s])) for s in split_names)
  logging.info('Finished retrieving splits. Size: %s', size_str)
  for tfds_split_name, cfq_split_name in split_names.items():
    for raw_x, raw_y in tfds.as_numpy(allsplits[tfds_split_name]):
      encode_decode_pair = (tokenize_punctuation(raw_x.decode()),
                            preprocess_sparql(raw_y.decode()))
      dataset[cfq_split_name].append(encode_decode_pair)

  size_str = ', '.join(f'{s}={len(dataset[s])}' for s in split_names)
  logging.info('Finished loading splits. Size: %s', size_str)
  return dataset
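
As a quick sanity check of the new loading path, the helper can also be driven
directly from Python. A minimal sketch (assumed to be run from the cfq/
directory; TFDS downloads and prepares the data on first use, which can take a
while):

```python
import preprocess as preprocessor

# Load the mcd1 split of CFQ through TFDS and apply the same tokenization
# the pipeline uses (tokenize_punctuation / preprocess_sparql).
dataset = preprocessor.get_dataset_from_tfds('cfq', 'mcd1')

# The result maps split names ('train', 'dev', 'test') to (question, query) pairs.
print({name: len(pairs) for name, pairs in dataset.items()})
```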


15 changes: 4 additions & 11 deletions cfq/preprocess_main.py
@@ -20,37 +20,30 @@
from __future__ import division
from __future__ import print_function

import os

from absl import app
from absl import flags

import preprocess as preprocessor

FLAGS = flags.FLAGS

flags.DEFINE_string('dataset_path', None, 'Path to the JSON file containing '
                    'the dataset.')
flags.DEFINE_string('dataset', None,
                    'Name of the TFDS dataset. Use cfq or scan.')

flags.DEFINE_string('split_path', None, 'Path to the JSON file containing '
                    'split information.')
flags.DEFINE_string('split', None,
                    'Name of the split to use (e.g. mcd1).')

flags.DEFINE_string('save_path', None, 'Path to the directory where to '
                    'save the files to.')

flags.mark_flag_as_required('save_path')

flags.register_validator('dataset_path', os.path.exists, 'Dataset not found.')
flags.register_validator('split_path', os.path.exists, 'Split not found.')


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  dataset = preprocessor.get_dataset(
      preprocessor.load_dataset(FLAGS.dataset_path),
      preprocessor.load_json(FLAGS.split_path))
  dataset = preprocessor.get_dataset_from_tfds(FLAGS.dataset, FLAGS.split)
  preprocessor.write_dataset(dataset, FLAGS.save_path)
  token_vocab = preprocessor.get_token_vocab(FLAGS.save_path)
  preprocessor.write_token_vocab(token_vocab, FLAGS.save_path)
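
Under the new flags, the preprocessing step can also be run on its own, outside
of run_experiment.sh. A sketch (the output directory is only an example):

```shell
# Preprocess the CFQ mcd1 split via TFDS and write the encode/decode files
# and the token vocabulary to the given directory.
python3 -m preprocess_main --dataset=cfq --split=mcd1 --save_path=/tmp/cfq_mcd1
```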
1 change: 1 addition & 0 deletions cfq/requirements.txt
@@ -2,3 +2,4 @@ absl-py>=0.8.1
tensorflow>=1.14.0,<2.0
tensor2tensor>=1.14.1
dataclasses
tensorflow-datasets>=3.0
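
For an environment that already has the other requirements installed, the one
new dependency from this commit can be added directly (a sketch; adapt the pip
invocation to your setup):

```shell
pip3 install "tensorflow-datasets>=3.0"
```
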
23 changes: 6 additions & 17 deletions cfq/run_experiment.sh
@@ -27,15 +27,11 @@ hparams_set="cfq_lstm_attention_multi"
# We report experiments with 35,000 steps in our paper.
train_steps="35000"

# URL to the CFQ dataset.
dataset_url="https://storage.cloud.google.com/cfq_dataset/cfq.tar.gz"
# The dataset to use (cfq or scan)
dataset="cfq"

# Local path to the dataset (after it has been downloaded).
dataset_local_path="dataset.json"

# Location of the dataset split to run the experiment for.
# The split of the dataset (random, mcd1, mcd2, mcd3).
split="mcd1"
split_path="splits/${split}.json"

# Evaluation results will be written to this path.
eval_results_path="evaluation-${model}-${split}.txt"
@@ -58,16 +54,9 @@ decode_path="${save_path}/dev/dev_decode.txt"
decode_inferred_path="${save_path}/dev/dev_decode_inferred.txt"

# ================= Pipeline ================
# Download dataset if it doesn't exist yet.
if [[ ! -f "${dataset_local_path}" || ! -f "${split_path}" ]]; then
  echo "ERROR: Dataset not found."
  echo "Please download the dataset first from ${dataset_url}!"
  echo "See further instructions in the README."
  exit 1
fi

python3 -m preprocess_main --dataset_path="${dataset_local_path}" \
--split_path="${split_path}" --save_path="${save_path}"

python3 -m preprocess_main --dataset="${dataset}" \
--split="${split}" --save_path="${save_path}"

t2t-datagen --t2t_usr_dir="${work_dir}/cfq/" --data_dir="${save_path}" \
--problem="${problem}" --tmp_dir="${tmp_path}"
Expand Down
