add cosql editsql
Rui Zhang committed Dec 1, 2019
1 parent 05ded8f commit 29a433c
Showing 4 changed files with 99 additions and 4 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -59,7 +59,7 @@ Download Pretrained BERT model from [here](https://drive.google.com/file/d/1f_LE
### Run Spider experiment
First, download [Spider](https://yale-lily.github.io/spider). Then please follow

-- `run_spider_editsql.sh`. We saved our experimental logs at `logs/logs_spider_editsql`
+- `run_spider_editsql.sh`. We saved our experimental logs at `logs/logs_spider_editsql`. The dev results can be reproduced by `test_spider_editsql.sh` with the trained model `logs/logs_spider_editsql/save_12`.

This reproduces the Spider result in "Editing-Based SQL Query Generation for Cross-Domain Context-Dependent Questions".

@@ -140,7 +140,8 @@ This reproduces the SParC result in "Editing-Based SQL Query Generation for Cros

First, download CoSQL from [here](https://yale-lily.github.io/cosql). Then please follow

-- `run_cosql_cdseq2seq.sh`. We saved our experimental logs at `logs/logs_cosql_cdseq2seq`
+- `run_cosql_cdseq2seq.sh`. We saved our experimental logs at `logs/logs_cosql_cdseq2seq`.
+- `run_cosql_editsql.sh`. We saved our experimental logs at `logs/logs_cosql_editsql`. The dev results can be reproduced by `test_cosql_editsql.sh` with the trained model downloaded from [here](https://drive.google.com/file/d/1ggf05rLVUpqamkEFbhu2CX35-PTGpFx4/view?usp=sharing) and put under `logs/logs_cosql_editsql/save_12_cosql_editsql`.

This reproduces the SQL-grounded dialog state tracking result in "CoSQL: A Conversational Text-to-SQL Challenge Towards Cross-Domain Natural Language Interfaces to Databases".

10 changes: 8 additions & 2 deletions preprocess.py
@@ -244,6 +244,7 @@ def check_oov(format_sql_final, output_vocab, schema_tokens):
     for sql_tok in format_sql_final.split():
         if not (sql_tok in schema_tokens or sql_tok in output_vocab):
             print('OOV!', sql_tok)
+            raise Exception('OOV')


 def normalize_space(format_sql):
@@ -401,7 +402,11 @@ def read_data_json(split_json, interaction_list, database_schemas, column_names,
                 continue

             if remove_from:
-                turn_sql_parse = parse_sql(turn_sql, db_id, column_names[db_id], output_vocab, schema_tokens[db_id], database_schemas[db_id])
+                try:
+                    turn_sql_parse = parse_sql(turn_sql, db_id, column_names[db_id], output_vocab, schema_tokens[db_id], database_schemas[db_id])
+                except:
+                    print('continue')
+                    continue
             else:
                 turn_sql_parse = turn_sql

@@ -411,7 +416,8 @@ def read_data_json(split_json, interaction_list, database_schemas, column_names,
                 turn_utterance = turn['utterance']

             interaction['interaction'].append({'utterance': turn_utterance, 'sql': turn_sql_parse})
-        interaction_list[db_id].append(interaction)
+        if len(interaction['interaction']) > 0:
+            interaction_list[db_id].append(interaction)

     return interaction_list
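
The `preprocess.py` hunks above work together: `check_oov` now raises on an out-of-vocabulary token, the turn loop catches the failure and skips that turn, and an interaction is stored only if at least one turn survived. A minimal, self-contained sketch of that flow follows. It is an illustration under assumptions, not the repository's code: `OOVError`, `collect_interactions`, and the toy vocabulary and data are hypothetical names, and the sketch catches a dedicated exception class, whereas the commit raises a plain `Exception` and uses a bare `except:`.

```python
class OOVError(Exception):
    """Hypothetical error: a SQL token is in neither the schema nor the output vocabulary."""


def check_oov(format_sql_final, output_vocab, schema_tokens):
    # Mirrors the check in the diff, but raises a dedicated exception type.
    for sql_tok in format_sql_final.split():
        if not (sql_tok in schema_tokens or sql_tok in output_vocab):
            raise OOVError(sql_tok)


def collect_interactions(raw_interactions, output_vocab, schema_tokens):
    """Keep turns that pass the OOV check; drop interactions left with no turns."""
    kept = []
    for raw in raw_interactions:
        turns = []
        for turn in raw:
            try:
                check_oov(turn['sql'], output_vocab, schema_tokens)
            except OOVError as err:
                # Catch only the expected failure so unrelated bugs still surface.
                print('skipping turn, OOV token:', err)
                continue
            turns.append(turn)
        if len(turns) > 0:
            kept.append(turns)
    return kept


# Toy data (hypothetical): one interaction with a good and a bad turn,
# and one interaction whose only turn is bad.
output_vocab = {'select', 'where', '=', 'value'}
schema_tokens = {'singer.name', 'singer.age'}
raw_interactions = [
    [{'utterance': 'names?', 'sql': 'select singer.name'},
     {'utterance': 'bad', 'sql': 'select singer.unknown'}],
    [{'utterance': 'only bad', 'sql': 'select foo'}],
]

kept = collect_interactions(raw_interactions, output_vocab, schema_tokens)
print(len(kept), len(kept[0]))
```

Catching a narrow exception type is a design choice of this sketch. A bare `except:` also swallows `KeyboardInterrupt` and genuine programming errors, which can hide how many turns were dropped and why.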

44 changes: 44 additions & 0 deletions run_cosql_editsql.sh
@@ -0,0 +1,44 @@
#! /bin/bash

# 1. preprocess dataset by the following. It will produce data/cosql_data_removefrom/

python3 preprocess.py --dataset=cosql --remove_from

# 2. train and evaluate.
# the result (models, logs, prediction outputs) are saved in $LOGDIR

GLOVE_PATH="/home/lily/rz268/dialog2sql/word_emb/glove.840B.300d.txt" # you need to change this
LOGDIR="logs_cosql_editsql"

CUDA_VISIBLE_DEVICES=0 python3 run.py --raw_train_filename="data/cosql_data_removefrom/train.pkl" \
--raw_validation_filename="data/cosql_data_removefrom/dev.pkl" \
--database_schema_filename="data/cosql_data_removefrom/tables.json" \
--embedding_filename=$GLOVE_PATH \
--data_directory="processed_data_cosql_removefrom" \
--input_key="utterance" \
--state_positional_embeddings=1 \
--discourse_level_lstm=1 \
--use_utterance_attention=1 \
--use_previous_query=1 \
--use_query_attention=1 \
--use_copy_switch=1 \
--use_schema_encoder=1 \
--use_schema_attention=1 \
--use_encoder_attention=1 \
--use_bert=1 \
--bert_type_abb=uS \
--fine_tune_bert=1 \
--use_schema_self_attention=1 \
--use_schema_encoder_2=1 \
--interaction_level=1 \
--reweight_batch=1 \
--freeze=1 \
--train=1 \
--logdir=$LOGDIR \
--evaluate=1 \
--evaluate_split="valid" \
--use_predicted_queries=1

# 3. get evaluation result

python3 postprocess_eval.py --dataset=cosql --split=dev --pred_file $LOGDIR/valid_use_predicted_queries_predictions.json --remove_from
44 changes: 44 additions & 0 deletions test_cosql_editsql.sh
@@ -0,0 +1,44 @@
#! /bin/bash

# 1. preprocess dataset by the following. It will produce data/cosql_data_removefrom/

# python3 preprocess.py --dataset=cosql --remove_from

# 2. train and evaluate.
# the result (models, logs, prediction outputs) are saved in $LOGDIR

GLOVE_PATH="/home/lily/rz268/dialog2sql/word_emb/glove.840B.300d.txt" # you need to change this
LOGDIR="logs/logs_cosql_editsql"

CUDA_VISIBLE_DEVICES=0 python3 run.py --raw_train_filename="data/cosql_data_removefrom/train.pkl" \
--raw_validation_filename="data/cosql_data_removefrom/dev.pkl" \
--database_schema_filename="data/cosql_data_removefrom/tables.json" \
--embedding_filename=$GLOVE_PATH \
--data_directory="processed_data_cosql_removefrom" \
--input_key="utterance" \
--state_positional_embeddings=1 \
--discourse_level_lstm=1 \
--use_utterance_attention=1 \
--use_previous_query=1 \
--use_query_attention=1 \
--use_copy_switch=1 \
--use_schema_encoder=1 \
--use_schema_attention=1 \
--use_encoder_attention=1 \
--use_bert=1 \
--bert_type_abb=uS \
--fine_tune_bert=1 \
--use_schema_self_attention=1 \
--use_schema_encoder_2=1 \
--interaction_level=1 \
--reweight_batch=1 \
--freeze=1 \
--logdir=$LOGDIR \
--evaluate=1 \
--evaluate_split="valid" \
--use_predicted_queries=1 \
--save_file="$LOGDIR/save_12_cosql_editsql"

# 3. get evaluation result

python3 postprocess_eval.py --dataset=cosql --split=dev --pred_file $LOGDIR/valid_use_predicted_queries_predictions.json --remove_from
