Skip to content

Commit

Permalink
modify get_tltschool_cefr_scales.py to callable file
Browse files Browse the repository at this point in the history
  • Loading branch information
a2d8a4v committed Jul 18, 2022
1 parent d65eceb commit eaff8c4
Showing 1 changed file with 37 additions and 13 deletions.
50 changes: 37 additions & 13 deletions utils/get_tltschool_cefr_scales.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,50 @@
## Functions
def opentext( file, col, filter ):
import argparse

def argparse_function(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), ``argparse`` falls back to ``sys.argv[1:]``, so
            existing zero-argument callers behave exactly as before.

    Returns:
        argparse.Namespace: holds ``input_wavscp_file_path`` and
        ``output_cefr_file_path`` as strings.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--input_wavscp_file_path",
        default='data/trn/text',
        type=str,
        help="input file whose first whitespace-separated column "
             "is the utterance id",
    )

    parser.add_argument(
        "--output_cefr_file_path",
        default='CEFR_LABELS_PATH/trn_cefr_scores.txt',
        type=str,
        help="output file; one '<utt_id> <cefr>' pair is written per line",
    )

    # Passing argv through lets other modules and tests call this function
    # without having to patch sys.argv.
    return parser.parse_args(argv)

def open_utt2cefr( file, col, filter ):
    """Build a mapping from utterance id to the label embedded in the id.

    Each line of ``file`` is split on whitespace and the field at index
    ``col`` is taken as the utterance id. The id is split on "-", and the
    second piece is split on "_" into tokens, e.g.::

        speakerIp18_A2_002002001008-promptIp18_A2_en_22_20_103
        -> ["promptIp18", "A2", "en", "22", "20", "103"]

    An utterance is kept when at least one entry of ``filter`` equals one
    of those tokens (exact, case-sensitive match); the stored value is the
    second token.

    Lines that cannot be parsed this way (blank lines, too few columns,
    ids without a "-", or fewer than two "_"-separated tokens after the
    dash) are skipped instead of raising ``IndexError``.

    Args:
        file: path of the text file to read.
        col: index of the whitespace-separated column holding the id.
        filter: iterable of label strings to keep (e.g. ["A1", "A2", "B1"]).

    Returns:
        dict: utterance id -> second "_"-separated token of the id's
        second "-"-separated piece.
    """
    # Materialize once so a one-shot iterable still works for every line.
    labels = tuple(filter)
    s = {}
    with open(file, "r") as f:
        for line in f:
            columns = line.split()
            try:
                utt_id = columns[col]
            except IndexError:
                # Blank line, or fewer columns than `col` requires.
                continue

            pieces = utt_id.split("-")
            if len(pieces) < 2:
                # No "-" separator, so there is no second piece to inspect.
                continue

            tokens = pieces[1].split("_")
            if len(tokens) < 2:
                # The stored value is tokens[1]; nothing to store here.
                continue

            # An exact token match implies the label also occurs as a
            # substring of the id, so no separate substring prefilter is
            # needed.
            if any(label in tokens for label in labels):
                s[utt_id] = tokens[1]
    return s

if __name__ == '__main__':
    # Local import: only the script entry point needs it.
    import os

    # Command-line options (input text file, output label file).
    args = argparse_function()
    input_wavscp_file_path = args.input_wavscp_file_path
    output_cefr_file_path = args.output_cefr_file_path

    # CEFR labels to keep; matched case-sensitively against the "_"-separated
    # tokens of each utterance id.
    CEFR_filter = ["A1", "A2", "B1"]
    utt2cefr_dict = open_utt2cefr(
        input_wavscp_file_path,
        0,  # the utterance id is the first column of each line
        CEFR_filter
    )

    # Create the output directory when it is missing, so the default
    # 'CEFR_LABELS_PATH/...' location works on a fresh checkout instead of
    # failing with FileNotFoundError.
    output_dir = os.path.dirname(output_cefr_file_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save one "<utt_id> <cefr>" pair per line.
    with open(output_cefr_file_path, "w") as f:
        for utt_id, cefr in utt2cefr_dict.items():
            f.write("{} {}\n".format(utt_id, cefr))

0 comments on commit eaff8c4

Please sign in to comment.