forked from coqui-ai/STT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtranscribe.py
executable file
·247 lines (225 loc) · 9.15 KB
/
transcribe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Command-line script that transcribes audio files into JSON transcription
# logs (.tlog). Flags are defined in the ``__main__`` section of this file.
from __future__ import absolute_import, division, print_function
import json
import os
import sys
# Deliberately set before TensorFlow is imported below.
# NOTE(review): presumably this silences TensorFlow's native log output - confirm.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow.compat.v1.logging as tflogging
import tensorflow as tf
# Only report errors from TensorFlow's Python-side logger.
tflogging.set_verbosity(tflogging.ERROR)
import logging
# Only report errors from the "sox" logger.
logging.getLogger("sox").setLevel(logging.ERROR)
import glob
from multiprocessing import Process, cpu_count
from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
from coqui_stt_training.util.audio import AudioFile
from coqui_stt_training.util.config import Config, initialize_globals_from_cli
from coqui_stt_training.util.feeding import split_audio_file
from coqui_stt_training.util.flags import FLAGS, create_flags
from coqui_stt_training.util.logging import (
create_progressbar,
log_error,
log_info,
log_progress,
)
def fail(message, code=1):
    """Report ``message`` through ``log_error`` and terminate the process.

    Args:
        message: Text handed to ``log_error``.
        code: Exit status of the process (defaults to 1).

    Raises:
        SystemExit: Always, carrying ``code``.
    """
    log_error(message)
    # Equivalent to sys.exit(code), which raises SystemExit itself.
    raise SystemExit(code)
def transcribe_file(audio_path, tlog_path):
    """Transcribe one audio file and write the result as a JSON transcription log.

    The audio is cut into segments by ``split_audio_file``; every batch of
    segments is run through the model built by ``create_model`` and decoded
    with ``ctc_beam_search_decoder_batch``. The decoded segments are sorted by
    their start value and dumped to ``tlog_path`` as a JSON list of objects
    with the keys ``"start"``, ``"end"`` and ``"transcript"``.

    Args:
        audio_path: Path of the audio file to transcribe.
        tlog_path: Path of the transcription log to write; the file is opened
            in ``"w"`` mode, so existing content is replaced.
    """
    # Function-level imports; the pylint directive marks them as a deliberate
    # way around a cyclic import.
    from coqui_stt_training.train import (  # pylint: disable=cyclic-import,import-outside-toplevel
        create_model,
    )
    from coqui_stt_training.util.checkpoints import load_graph_for_evaluation

    initialize_globals_from_cli()
    lm_scorer = Scorer(
        FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet
    )

    # cpu_count() can raise NotImplementedError; decode with one process then.
    try:
        decoder_processes = cpu_count()
    except NotImplementedError:
        decoder_processes = 1

    with AudioFile(audio_path, as_path=True) as wav_path:
        segments = split_audio_file(
            wav_path,
            batch_size=FLAGS.batch_size,
            aggressiveness=FLAGS.vad_aggressiveness,
            outlier_duration_ms=FLAGS.outlier_duration_ms,
            outlier_batch_size=FLAGS.outlier_batch_size,
        )
        segment_iter = tf.data.Iterator.from_structure(
            segments.output_types,
            segments.output_shapes,
            output_classes=segments.output_classes,
        )
        start_op, end_op, features, feature_lengths = segment_iter.get_next()

        # Six ``None`` entries are passed as the dropout argument.
        # NOTE(review): presumably this disables all dropout - confirm.
        logits, _ = create_model(
            batch_x=features, seq_length=feature_lengths, dropout=[None] * 6
        )
        # Swap the first two axes of the logits before applying softmax.
        probabilities_op = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
        tf.train.get_or_create_global_step()
        fetches = [start_op, end_op, probabilities_op, feature_lengths]

        with tf.Session(config=Config.session_config) as session:
            load_graph_for_evaluation(session)
            session.run(segment_iter.make_initializer(segments))

            rows = []
            while True:
                # The iterator signals exhaustion with OutOfRangeError.
                try:
                    fetched = session.run(fetches)
                except tf.errors.OutOfRangeError:
                    break
                seg_starts, seg_ends, probabilities, lengths = fetched
                results = ctc_beam_search_decoder_batch(
                    probabilities,
                    lengths,
                    Config.alphabet,
                    FLAGS.beam_width,
                    num_processes=decoder_processes,
                    scorer=lm_scorer,
                )
                # Keep element [0][1] of each decoder result.
                # NOTE(review): presumably the text of the top candidate - confirm.
                texts = [candidates[0][1] for candidates in results]
                rows.extend(zip(seg_starts, seg_ends, texts))

            # Order by start value, then convert to JSON-friendly records.
            records = [
                {"start": int(begin), "end": int(finish), "transcript": text}
                for begin, finish, text in sorted(rows, key=lambda row: row[0])
            ]
            with open(tlog_path, "w") as tlog_file:
                json.dump(records, tlog_file, default=float)
def transcribe_many(src_paths, dst_paths):
    """Transcribe several audio files sequentially, each in its own process.

    Every file is handed to ``transcribe_file`` inside a freshly started
    ``multiprocessing.Process`` that is joined before the next one starts.
    NOTE(review): presumably a separate process is used so that per-file
    TensorFlow state is released between files - confirm.

    Args:
        src_paths: Sequence of audio file paths to transcribe.
        dst_paths: Sequence of transcription log paths, parallel to
            ``src_paths``.
    """
    total = len(src_paths)
    pbar = create_progressbar(
        prefix="Transcribing files | ", max_value=total
    ).start()
    for done, (src_path, dst_path) in enumerate(zip(src_paths, dst_paths), start=1):
        worker = Process(target=transcribe_file, args=(src_path, dst_path))
        worker.start()
        worker.join()
        if worker.exitcode != 0:
            # Do not claim success when the child process died or raised.
            log_error(
                'Transcribing file {} of {} from "{}" to "{}" failed (exit code {})'.format(
                    done, total, src_path, dst_path, worker.exitcode
                )
            )
        else:
            log_progress(
                'Transcribed file {} of {} from "{}" to "{}"'.format(
                    done, total, src_path, dst_path
                )
            )
        # Report the number of files processed so far (1-based), so the bar
        # reaches its maximum after the last file.
        pbar.update(done)
    pbar.finish()
def transcribe_one(src_path, dst_path):
    """Transcribe a single file in the current process and log the outcome.

    Args:
        src_path: Path of the audio file to transcribe.
        dst_path: Path of the transcription log to write.
    """
    transcribe_file(src_path, dst_path)
    summary = 'Transcribed file "{}" to "{}"'.format(src_path, dst_path)
    log_info(summary)
def resolve(base_path, spec_path):
    """Return ``spec_path`` anchored at ``base_path`` unless it is already absolute.

    Args:
        base_path: Directory that relative paths are joined onto.
        spec_path: Path to resolve, or ``None``.

    Returns:
        ``None`` if ``spec_path`` is ``None``; ``spec_path`` unchanged if it is
        absolute; otherwise ``os.path.join(base_path, spec_path)``.
    """
    # Nothing to join for a missing or already absolute path.
    if spec_path is None or os.path.isabs(spec_path):
        return spec_path
    return os.path.join(base_path, spec_path)
def main(_):
    """Dispatch transcription according to what ``--src`` points to.

    * A ``.catalog`` file: JSON list of entries with ``"audio"`` and ``"tlog"``
      paths (relative paths are resolved against the catalog's directory);
      all entries are transcribed via ``transcribe_many``. ``--dst`` is
      rejected in this mode.
    * Any other file: transcribed via ``transcribe_one`` to ``--dst`` or, if
      that is not set, to the source path with its extension replaced by
      ``.tlog``.
    * A directory: every ``*.wav`` file in it (and, with ``--recursive``, in
      all of its sub-directories) is transcribed next to its source.
      ``--dst`` is rejected in this mode.

    Validation problems are reported through ``fail``, which terminates the
    process.

    Args:
        _: Unused; positional argument supplied by ``tf.app.run``.
    """
    if not FLAGS.src or not os.path.exists(FLAGS.src):
        # path not given or non-existent
        fail(
            "You have to specify which file or catalog to transcribe via the --src flag."
        )
    else:
        # path given and exists
        src_path = os.path.abspath(FLAGS.src)
        if os.path.isfile(src_path):
            if src_path.endswith(".catalog"):
                # Transcribe batch of files via ".catalog" file (from DSAlign)
                if FLAGS.dst:
                    fail("Parameter --dst not supported if --src points to a catalog")
                catalog_dir = os.path.dirname(src_path)
                with open(src_path, "r") as catalog_file:
                    catalog_entries = json.load(catalog_file)
                # (audio path, tlog path) pairs, relative to the catalog's directory
                catalog_entries = [
                    (resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"]))
                    for e in catalog_entries
                ]
                if not catalog_entries:
                    # zip(*[]) below could not be unpacked into two sequences
                    fail("Catalog contains no entries")
                if any(not os.path.isfile(entry[0]) for entry in catalog_entries):
                    fail("Missing source file(s) in catalog")
                if not FLAGS.force and any(
                    os.path.isfile(entry[1]) for entry in catalog_entries
                ):
                    fail(
                        "Destination file(s) from catalog already existing, use --force for overwriting"
                    )
                if any(
                    not os.path.isdir(os.path.dirname(entry[1]))
                    for entry in catalog_entries
                ):
                    fail("Missing destination directory for at least one catalog entry")
                # Split the pairs into parallel source / destination sequences.
                src_paths, dst_paths = zip(*catalog_entries)
                transcribe_many(src_paths, dst_paths)
            else:
                # Transcribe one file
                dst_path = (
                    os.path.abspath(FLAGS.dst)
                    if FLAGS.dst
                    else os.path.splitext(src_path)[0] + ".tlog"
                )
                if os.path.isfile(dst_path):
                    if FLAGS.force:
                        transcribe_one(src_path, dst_path)
                    else:
                        fail(
                            'Destination file "{}" already existing - use --force for overwriting'.format(
                                dst_path
                            ),
                            code=0,
                        )
                elif os.path.isdir(os.path.dirname(dst_path)):
                    transcribe_one(src_path, dst_path)
                else:
                    fail("Missing destination directory")
        elif os.path.isdir(src_path):
            # Transcribe all files in dir
            print("Transcribing all WAV files in --src")
            if FLAGS.dst:
                fail("Destination file not supported for batch decoding jobs.")
            else:
                if not FLAGS.recursive:
                    print(
                        "If you wish to recursively scan --src, then you must use --recursive"
                    )
                    wav_paths = glob.glob(src_path + "/*.wav")
                else:
                    # "**" only descends into sub-directories with recursive=True
                    wav_paths = glob.glob(src_path + "/**/*.wav", recursive=True)
                # Swap only the file extension; str.replace would also rewrite
                # ".wav" occurring elsewhere in the path.
                dst_paths = [
                    os.path.splitext(path)[0] + ".tlog" for path in wav_paths
                ]
                transcribe_many(wav_paths, dst_paths)
if __name__ == "__main__":
    # Register the flags provided by coqui_stt_training before adding the
    # script-specific ones below.
    create_flags()
    tf.app.flags.DEFINE_string(
        "src",
        "",
        "Source path to an audio file or directory or catalog file. "
        "Catalog files should be formatted from DSAlign. A directory will "
        "be searched for WAV audio (recursively only if --recursive is set). "
        "If --dst not set, transcription logs (.tlog) will be "
        "written in-place using the source filenames with "
        'suffix ".tlog" instead of ".wav".',
    )
    tf.app.flags.DEFINE_string(
        "dst",
        "",
        "path for writing the transcription log (.tlog). "
        "Only supported if --src is a single audio file; "
        "not supported if --src is a directory or a catalog.",
    )
    tf.app.flags.DEFINE_boolean("recursive", False, "scan dir of audio recursively")
    tf.app.flags.DEFINE_boolean(
        "force",
        False,
        "Forces re-transcribing and overwriting of already existing "
        "transcription logs (.tlog)",
    )
    tf.app.flags.DEFINE_integer(
        "vad_aggressiveness",
        3,
        "How aggressive (0=lowest, 3=highest) the VAD should split audio",
    )
    tf.app.flags.DEFINE_integer("batch_size", 40, "Default batch size")
    tf.app.flags.DEFINE_float(
        "outlier_duration_ms",
        10000,
        "Duration in ms after which samples are considered outliers",
    )
    tf.app.flags.DEFINE_integer(
        "outlier_batch_size", 1, "Batch size for duration outliers (defaults to 1)"
    )
    # Hand control to main() through TensorFlow's app runner.
    tf.app.run(main)