Skip to content

Commit

Permalink
add mother language column
Browse files Browse the repository at this point in the history
  • Loading branch information
a2d8a4v committed Jul 12, 2022
1 parent abac1af commit cdcd234
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
6 changes: 5 additions & 1 deletion prepare_scales.sh
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ] ; then
--input_text_file_path $data_dir/text \
--input_spk2utt_file_path $data_dir/spk2utt \
--input_cefr_label_file_path $data_dir/scale \
--input_spk2momlang_file_path $data_dir/momlanguage \
--output_text_file_path $data_dir/text.tsv \
--remove_filled_pauses $remove_filled_pauses \
--combine_same_speakerids $combine_same_speakerids
Expand All @@ -54,11 +55,13 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ] ; then
mkdir -pv data/${test_set} > /dev/null 2>&1
done

momlang_names=""
cefr_scores_names=""
text_names=""
sp2utt_names=""
for test_set in $test_sets; do
data_dir=$data_root/$test_set
momlang_names+="${data_root}/momlanguage "
cefr_scores_names+="${data_dir}/scale "
text_names+="${data_dir}/text "
sp2utt_names+="${data_dir}/spk2utt "
Expand All @@ -76,9 +79,10 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ] ; then
--input_json_file_path $dest_dir/all.json \
--input_text_file_path $data_dir/text \
--input_spk2utt_file_path $data_dir/spk2utt \
--input_spk2momlang_file_path $data_dir/momlanguage \
--input_cefr_label_file_path $data_dir/scale \
--output_text_file_path $data_dir/text.tsv \
--remove_filled_pauses $remove_filled_pauses \
--combine_same_speakerids $combine_same_speakerids
done
fi
fi
18 changes: 15 additions & 3 deletions utils/prepare_auto_grader_feats.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def argparse_function():
default='data/trn/text',
type=nullable_string)

parser.add_argument("--input_spk2momlang_file_path",
default='data/trn/momlanguage',
type=nullable_string)

parser.add_argument("--input_cefr_label_file_path",
default='CEFR_LABELS_PATH/trn_cefr_scores.txt',
type=str)
Expand Down Expand Up @@ -115,6 +119,7 @@ def xstr(s):
utt_text_dict = { utt_id:utt_info.get('stt') for utt_id, utt_info in utt_json_data.items() if xstr(utt_info.get('stt')).strip() }

utt_cefr_file_path_dict = open_utt2value(args.input_cefr_label_file_path)
utt_momlang_dict = open_utt2value(args.input_spk2momlang_file_path)

if remove_filled_pauses:
utt_text_dict = { utt_id:remove_tltschool_interregnum_tokens(texts) for utt_id, texts in utt_text_dict.items() }
Expand All @@ -123,49 +128,56 @@ def xstr(s):
utt_text_dict = { utt_id:remove_partial_words_call(texts) for utt_id, texts in utt_text_dict.items() }

if combine_same_speakerids:
assert input_spk2utt_file_path is not None, 'You need to point a specific path for input_spk2utt_file_path'

spk2utt_dict = opendict(input_spk2utt_file_path)
new_utt_text_dict = {}
new_utt_cefr_file_path_dict = {}
new_utt_momlang_dict = {}
for spk_id, utts_list in spk2utt_dict.items():
text_list = []
for utt_id in utts_list:
if utt_id in utt_text_dict: # BUG: some recognized result has empty result!
text_list.extend(utt_text_dict[utt_id].split())
new_utt_cefr_file_path_dict.setdefault(spk_id, utt_cefr_file_path_dict[utt_id])
new_utt_momlang_dict.setdefault(spk_id, utt_momlang_dict[utt_id])
if text_list: # BUG: some recognized result has empty result!
new_utt_text_dict[spk_id] = " ".join(text_list)
utt_text_dict = new_utt_text_dict
utt_cefr_file_path_dict = new_utt_cefr_file_path_dict
utt_momlang_dict = new_utt_momlang_dict

if get_specific_labels is not None:
count_cefr_labels = 0

sst = 0
max_seq_len = 0
with open(args.output_text_file_path, 'w') as f:
f.write("{}\t{}\t{}\n".format('score', 'sst', 'text'))
f.write("{}\t{}\t{}\t{}\n".format('score', 'sst', 'l1', 'text'))
for utt_or_spk_id, text in utt_text_dict.items():

if len(text.split()) > max_seq_len:
max_seq_len = len(text.split())

if get_specific_labels is not None:
if get_specific_labels.lower() == utt_cefr_file_path_dict[utt_or_spk_id].lower():
f.write("{}\t{}\t{}\n".format(
f.write("{}\t{}\t{}\t{}\n".format(
mapping_cefr2num(
utt_cefr_file_path_dict[utt_or_spk_id]
),
sst,
utt_momlang_dict[utt_or_spk_id],
text
)
)
count_cefr_labels+=1
else:
f.write("{}\t{}\t{}\n".format(
f.write("{}\t{}\t{}\t{}\n".format(
mapping_cefr2num(
utt_cefr_file_path_dict[utt_or_spk_id]
),
sst,
utt_momlang_dict[utt_or_spk_id],
text
)
)
Expand Down

0 comments on commit cdcd234

Please sign in to comment.