Skip to content

Commit

Permalink
adding example 4
Browse files Browse the repository at this point in the history
  • Loading branch information
saransh-mehta committed Jun 11, 2020
1 parent 3c07346 commit ebefeaf
Show file tree
Hide file tree
Showing 9 changed files with 242 additions and 14 deletions.
2 changes: 1 addition & 1 deletion docs/source/data_transformations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Sample transform functions
^^^^^^^^^^^^^^^^^^^^^^^^^^
.. automodule:: utils.tranform_functions
:members: snips_intent_ner_to_tsv, snli_entailment_to_tsv,create_fragment_detection_tsv,
msmarco_answerability_detection_to_tsv, bio_ner_to_tsv, msmarco_query_type_to_tsv, qqp_query_similarity_to_tsv
msmarco_answerability_detection_to_tsv, msmarco_query_type_to_tsv, bio_ner_to_tsv, msmarco_query_type_to_tsv, qqp_query_similarity_to_tsv

Your own transform function
^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
2 changes: 1 addition & 1 deletion docs/source/shared_encoder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,5 @@ For evaluating the performance on dev and test sets during training, we provide

.. automodule:: utils.eval_metrics
:members: classification_accuracy, classification_f1_score, seqeval_f1_score,
seqeval_precision, seqeval_recall, snips_f1_score, snips_precision, snips_recall
seqeval_precision, seqeval_recall, snips_f1_score, snips_precision, snips_recall, classification_recall_score

12 changes: 6 additions & 6 deletions examples/entailment_detection/entailment_snli.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@
"!python ../../data_preparation.py \\\n",
" --task_file 'tasks_file_snli.yml' \\\n",
" --data_dir '../../data' \\\n",
" --max_seq_len 384"
" --max_seq_len 128"
]
},
{
Expand All @@ -135,13 +135,13 @@
" --task_file 'tasks_file_snli.yml' \\\n",
" --out_dir 'snli_entailment_bert_base' \\\n",
" --epochs 3 \\\n",
" --train_batch_size 8 \\\n",
" --eval_batch_size 16 \\\n",
" --grad_accumulation_steps 2 \\\n",
" --log_per_updates 50 \\\n",
" --train_batch_size 64 \\\n",
" --eval_batch_size 64 \\\n",
" --grad_accumulation_steps 1 \\\n",
" --log_per_updates 100 \\\n",
" --eval_while_train True \\\n",
" --test_while_train True \\\n",
" --max_seq_len 384 \\\n",
" --max_seq_len 128 \\\n",
" --silent True "
]
},
Expand Down
97 changes: 97 additions & 0 deletions examples/query_type_detection/query_type_detection.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!wget https://msmarco.blob.core.windows.net/msmarco/train_v2.1.json.gz -P msmarco_qna_data\n",
"!wget https://msmarco.blob.core.windows.net/msmarco/dev_v2.1.json.gz -P msmarco_qna_data\n",
"!wget https://msmarco.blob.core.windows.net/msmarco/eval_v2.1_public.json.gz -P msmarco_qna_data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!gunzip msmarco_qna_data/train_v2.1.json.gz\n",
"!gunzip msmarco_qna_data/dev_v2.1.json.gz\n",
"!gunzip msmarco_qna_data/eval_v2.1_public.json.gz"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!python ../../data_transformations.py \\\n",
" --transform_file 'transform_file_querytype.yml'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!python ../../data_preparation.py \\\n",
" --task_file 'tasks_file_querytype.yml' \\\n",
" --data_dir '../../data' \\\n",
" --max_seq_len 60"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!python ../../train.py \\\n",
" --data_dir '../../data/bert-base-uncased_prepared_data' \\\n",
" --task_file 'tasks_file_querytype.yml' \\\n",
" --out_dir 'msmarco_querytype_bert_base' \\\n",
" --epochs 3 \\\n",
" --train_batch_size 64 \\\n",
" --eval_batch_size 64 \\\n",
" --grad_accumulation_steps 1 \\\n",
" --log_per_updates 100 \\\n",
" --eval_while_train True \\\n",
" --test_while_train True \\\n",
" --max_seq_len 60 \\\n",
" --silent True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
18 changes: 18 additions & 0 deletions examples/query_type_detection/tasks_file_querytype.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
querytype:
model_type: BERT
config_name: bert-base-uncased
dropout_prob: 0.2
label_map_or_file:
- DESCRIPTION
- ENTITY
- LOCATION
- NUMERIC
- PERSON
metrics:
- classification_accuracy
loss_type: CrossEntropyLoss
task_type: SingleSenClassification
file_names:
- querytype_train_v2.1.tsv
- querytype_dev_v2.1.tsv
- querytype_eval_v2.1_public.tsv
11 changes: 11 additions & 0 deletions examples/query_type_detection/transform_file_querytype.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
transform1:
transform_func: msmarco_query_type_to_tsv
transform_params:
data_frac : 0.2
read_file_names:
- train_v2.1.json
- dev_v2.1.json
- eval_v2.1_public.json

read_dir: msmarco_qna_data
save_dir: ../../data
6 changes: 4 additions & 2 deletions utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
"seqeval_recall" : seqeval_recall,
"snips_f1_score" : snips_f1_score,
"snips_precision" : snips_precision,
"snips_recall" : snips_recall
"snips_recall" : snips_recall,
"classification_recall" : classification_recall
}

TRANSFORM_FUNCS = {
Expand All @@ -37,7 +38,8 @@
"msmarco_query_type_to_tsv" : msmarco_query_type_to_tsv,
"imdb_sentiment_data_to_tsv" : imdb_sentiment_data_to_tsv,
"qqp_query_similarity_to_tsv" : qqp_query_similarity_to_tsv,
"msmarco_answerability_detection_to_tsv" : msmarco_answerability_detection_to_tsv
"msmarco_answerability_detection_to_tsv" : msmarco_answerability_detection_to_tsv,
"clinc_out_of_scope_to_tsv" : clinc_out_of_scope_to_tsv
}

class ModelType(IntEnum):
Expand Down
18 changes: 17 additions & 1 deletion utils/eval_metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@

"""
File for creating metric functions
"""
from sklearn.metrics import accuracy_score, f1_score
from sklearn.metrics import recall_score as class_recall_score
from seqeval.metrics import f1_score as seq_f1
from seqeval.metrics import precision_score, recall_score

Expand Down Expand Up @@ -31,6 +34,19 @@ def classification_f1_score(yTrue, yPred):
"""
return f1_score(yTrue, yPred, average='micro')

def classification_recall(yTrue, yPred):
"""
Standard recall score from sklearn for classification tasks.
It takes a batch of predictions and labels.
To use this metric, add **classification_f1_score** into list of ``metrics`` in task file.
Args:
yPred (:obj:`list`) : [0, 2, 1, 3]
yTrue (:obj:`list`) : [0, 1, 2, 3]
"""
return class_recall_score(yTrue, yPred, average='micro')

def seqeval_f1_score(yTrue, yPred):
"""
Expand Down
90 changes: 87 additions & 3 deletions utils/tranform_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import os
import re
import json
import random
import pandas as pd
from tqdm import tqdm
from collections import defaultdict
from statistics import median
from sklearn.model_selection import train_test_split
SEED = 42
Expand Down Expand Up @@ -440,12 +442,12 @@ def msmarco_query_type_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrain
#saving
print('number of samples in final data : ', len(dfKeep))
print('writing for file {} at {}'.format(readFile, wrtDir))
dfKeep.to_csv(os.path.join(wrtDir, 'querytype_{}.tsv'.format(readFile.split('.')[0])), sep='\t',
dfKeep.to_csv(os.path.join(wrtDir, 'querytype_{}.tsv'.format(readFile.lower().replace('.json', ''))), sep='\t',
index=False, header=False)
if isTrainFile:
allClasses = dfKeep['query_type'].unique()
labelMap = {lab : i for i, lab in enumerate(allClasses)}
labelMapPath = os.path.join(wrtDir, 'querytype_{}_label_map.joblib'.format(readFile.split('.')[0]))
labelMapPath = os.path.join(wrtDir, 'querytype_{}_label_map.joblib'.format(readFile.lower().replace('.json', '')))
joblib.dump(labelMap, labelMapPath)
print('Created label map file at', labelMapPath)

Expand Down Expand Up @@ -668,4 +670,86 @@ def msmarco_answerability_detection_to_tsv(dataDir, readFile, wrtDir, transParam
print('Dev file written at: ', os.path.join(wrtDir, 'msmarco_answerability_dev.tsv'))

devDf.to_csv(os.path.join(wrtDir, 'msmarco_answerability_test.tsv'), sep='\t', index=False, header=False)
print('Test file written at: ', os.path.join(wrtDir, 'msmarco_answerability_test.tsv'))
print('Test file written at: ', os.path.join(wrtDir, 'msmarco_answerability_test.tsv'))

def clinc_out_of_scope_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):

"""
For using this transform function, set ``transform_func`` : **clinc_out_of_scope_to_tsv** in transform file.
Args:
dataDir (:obj:`str`) : Path to the directory where the raw data files to be read are present..
readFile (:obj:`str`) : This is the file which is currently being read and transformed by the function.
wrtDir (:obj:`str`) : Path to the directory where to save the transformed tsv files.
transParamDict (:obj:`dict`, defaults to :obj:`None`): Dictionary requiring the following parameters as key-value
- ``samples_per_intent_train`` (defaults to 7) : Number of in-scope samples per intent to consider, as this data has imbalance for inscope and outscope
"""
transParamDict.setdefault("samples_per_intent_train", 7)

print("Making data from file {} ...".format(readFile))
raw = json.load(open(os.path.join(dataDir, readFile)))

print('Num of train samples in-scope: ', len(raw['train']))
inScopeTrain = defaultdict(list)
for sentence, intent in raw['train']:
inScopeTrain[intent].append(sentence)

#sampling
inscopeSampledTrain = []
numSamplesPerInt = 7
random.seed(SEED)
for intent in inScopeTrain:
inscopeSampledTrain += random.sample(inScopeTrain[intent], int(transParamDict["samples_per_intent_train"]))

print('Num of sampled train samples in-scope: ', len(inscopeSampledTrain))
#out of scope train
outscopeTrain = [sample[0] for sample in raw['oos_train']]
print('Num of train out-scope samples: ', len(outscopeTrain))

#train data
allTrain = inscopeSampledTrain + outscopeTrain
allTrainLabels = [1]*len(inscopeSampledTrain) + [0]*len(outscopeTrain)

#writing train data file
trainF = open(os.path.join(wrtDir, 'clinc_outofscope_train.tsv'), 'w')
for uid, (samp, lab) in enumerate(zip(allTrain, allTrainLabels)):
trainF.write("{}\t{}\t{}\n".format(uid, lab, samp))
print('Train file written at: ', os.path.join(wrtDir, 'clinc_outofscope_train.tsv'))
trainF.close()

#making dev set
inscopeDev = [sample[0] for sample in raw['val']]
outscopeDev = [sample[0] for sample in raw['oos_val']]
print('Num of val out-scope samples: ', len(outscopeDev))
print('Num of val in-scope samples: ', len(inscopeDev))

#allDev = inscopeDev + outscopeDev
allDev = outscopeDev
#allDevLabels = [1]*inscopeDev + [0]*outscopeDev
allDevLabels = [0]*len(outscopeDev)

#writing dev data file
devF = open(os.path.join(wrtDir, 'clinc_outofscope_dev.tsv'), 'w')
for uid, (samp, lab) in enumerate(zip(allDev, allDevLabels)):
devF.write("{}\t{}\t{}\n".format(uid, lab, samp))
print('Dev file written at: ', os.path.join(wrtDir, 'clinc_outofscope_dev.tsv'))
devF.close()

#making test set
inscopeTest = [sample[0] for sample in raw['test']]
outscopeTest = [sample[0] for sample in raw['oos_test']]
print('Num of test out-scope samples: ', len(outscopeTest))
print('Num of test in-scope samples: ', len(inscopeTest))

allTest = inscopeTest + outscopeTest
allTestLabels = [1]*len(inscopeTest) + [0]*len(outscopeTest)

#writing test data file
testF = open(os.path.join(wrtDir, 'clinc_outofscope_test.tsv'), 'w')
for uid, (samp, lab) in enumerate(zip(allTest, allTestLabels)):
testF.write("{}\t{}\t{}\n".format(uid, lab, samp))
print('Test file written at: ', os.path.join(wrtDir, 'clinc_outofscope_test.tsv'))
testF.close()

0 comments on commit ebefeaf

Please sign in to comment.