transform functions fix
saransh-mehta committed Jun 15, 2020
1 parent 38dc0ce commit 0046faa
Showing 1 changed file with 91 additions and 1 deletion.
utils/tranform_functions.py (91 additions, 1 deletion)
@@ -339,6 +339,95 @@ def generate_ngram_sequences(data, seq_len_right, seq_len_left):
            sequence_dict[key] = left_seq + right_seq
            i += 1
    return sequence_dict
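For context, a hedged sketch (an assumption, not part of this commit) of how generate_ngram_sequences is called; its body is collapsed above, but the call shape matches its use further down in this diff:

from utils.tranform_functions import generate_ngram_sequences

# sentences in -> dict of candidate n-gram strings mapped to scores
seq_dict = generate_ngram_sequences(
    data=["book a cab to the office", "cancel my last order"],
    seq_len_right=3,
    seq_len_left=2,
)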

def validate_sequences(sequence_dict, seq_len_right, seq_len_left):
    # seq_len_left mirrors the generate_ngram_sequences signature but is not used here
    micro_sequences = []
    macro_sequences = {}

    # split candidates: low-score short n-grams become micro sequences,
    # everything else is kept as a macro sequence with its score
    for key in sequence_dict.keys():
        score = sequence_dict[key]

        if score < 1 and len(key.split()) <= seq_len_right:
            micro_sequences.append(key)
        else:
            macro_sequences[key] = score

    non_frag_sequences = []
    macro_sequences_copy = macro_sequences.copy()

    # a macro sequence containing a micro sequence is a complete (non-fragment) sentence
    for sent in tqdm(micro_sequences, total=len(micro_sequences)):
        for key in macro_sequences.keys():
            if sent in key:
                non_frag_sequences.append(key)
                del macro_sequences_copy[key]

        macro_sequences = macro_sequences_copy.copy()

    # re-add the matched and micro sequences with score 0
    for sent in non_frag_sequences:
        macro_sequences[sent] = 0

    for sent in micro_sequences:
        macro_sequences[sent] = 0

    return macro_sequences
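A minimal usage sketch (an assumption, not part of this commit) showing the shape of data validate_sequences consumes and returns:

from utils.tranform_functions import validate_sequences

# n-gram string -> score, as produced by generate_ngram_sequences
toy_scores = {
    "book a": 0,                # score < 1 and short -> micro sequence
    "book a cab to office": 3,  # contains a micro sequence -> non-fragment key
    "cancel my": 0,             # another micro sequence
}
frags = validate_sequences(toy_scores, seq_len_right=3, seq_len_left=2)
print(frags)  # in this toy case all three keys survive, each re-scored to 0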

def create_fragment_detection_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):

"""
This function transforms data for fragment detection task (detecting whether a sentence is incomplete/fragment or not).
It takes data in single sentence classification format and creates fragment samples from the sentences.
In the transformed file, label 1 and 0 represent fragment and non-fragment sentence respectively.
Following transformed files are written at wrtDir
- Fragment transformed tsv file containing fragment/non-fragment sentences and labels
For using this transform function, set ``transform_func`` : **create_fragment_detection_tsv** in transform file.
Args:
dataDir (:obj:`str`) : Path to the directory where the raw data files to be read are present..
readFile (:obj:`str`) : This is the file which is currently being read and transformed by the function.
wrtDir (:obj:`str`) : Path to the directory where to save the transformed tsv files.
transParamDict (:obj:`dict`, defaults to :obj:`None`): Dictionary requiring the following parameters as key-value
- ``data_frac`` (defaults to 0.2) : Fraction of data to consider for making fragments.
- ``seq_len_right`` : (defaults to 3) : Right window length for making n-grams.
- ``seq_len_left`` (defaults to 2) : Left window length for making n-grams.
- ``sep`` (defaults to "\t") : column separator for input file.
- ``query_col`` (defaults to 2) : column number containing sentences. Counting starts from 0.
"""

    transParamDict.setdefault("data_frac", 0.2)
    transParamDict.setdefault("seq_len_right", 3)
    transParamDict.setdefault("seq_len_left", 2)
    transParamDict.setdefault("sep", "\t")
    transParamDict.setdefault("query_col", 2)

    allDataDf = pd.read_csv(os.path.join(dataDir, readFile), sep=transParamDict["sep"], header=None)
    sampledDataDf = allDataDf.sample(frac=float(transParamDict['data_frac']), random_state=42)

    # by default the 2nd column holds the queries (0th: uid, 1st: label)
    # making n-grams with the left and right windows; window sizes are cast
    # to int in case they arrive as strings from the transform file
    seqDict = generate_ngram_sequences(data=list(sampledDataDf.iloc[:, int(transParamDict["query_col"])]),
                                       seq_len_right=int(transParamDict['seq_len_right']),
                                       seq_len_left=int(transParamDict['seq_len_left']))

    fragDict = validate_sequences(seqDict, seq_len_right=int(transParamDict['seq_len_right']),
                                  seq_len_left=int(transParamDict['seq_len_left']))

    finalDf = pd.DataFrame({'uid': [i for i in range(len(fragDict) + len(allDataDf))],
                            'label': [1] * len(fragDict) + [0] * len(allDataDf),
                            'query': list(fragDict.keys()) + list(allDataDf.iloc[:, int(transParamDict["query_col"])])})

    print('number of fragment samples : ', len(fragDict))
    print('number of non-fragment samples : ', len(allDataDf))

    # saving
    print('writing fragment file for {} at {}'.format(readFile, wrtDir))

    finalDf.to_csv(os.path.join(wrtDir, 'fragment_{}.tsv'.format(readFile.split('.')[0])), sep='\t',
                   index=False, header=False)
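A hedged usage sketch (directory, file name, and column layout are assumptions) for invoking the transform directly:

from utils.tranform_functions import create_fragment_detection_tsv

# assumes data/queries.tsv has rows like: uid <tab> label <tab> query
create_fragment_detection_tsv(
    dataDir="data",
    readFile="queries.tsv",
    wrtDir="out",
    transParamDict={"data_frac": 0.1, "query_col": 2},
)
# writes out/fragment_queries.tsv with label 1 = fragment, 0 = complete sentence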


def msmarco_answerability_detection_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):
"""
This function transforms the MSMARCO triples data available at `triples <https://msmarco.blob.core.windows.net/msmarcoranking/triples.train.small.tar.gz>`_
Expand Down Expand Up @@ -412,7 +501,7 @@ def msmarco_answerability_detection_to_tsv(dataDir, readFile, wrtDir, transParam

devDf.to_csv(os.path.join(wrtDir, 'msmarco_answerability_test.tsv'), sep='\t', index=False, header=False)
print('Test file written at: ', os.path.join(wrtDir, 'msmarco_answerability_test.tsv'))

def msmarco_query_type_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):

"""
Expand Down Expand Up @@ -458,6 +547,7 @@ def msmarco_query_type_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrain
labelMapPath = os.path.join(wrtDir, 'querytype_{}_label_map.joblib'.format(readFile.lower().replace('.json', '')))
joblib.dump(labelMap, labelMapPath)
print('Created label map file at', labelMapPath)


def imdb_sentiment_data_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):

