forked from mozilla/DeepSpeech
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path__init__.py
355 lines (273 loc) · 12.4 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
import os
import platform
#The API is not snake case which triggers linter errors
#pylint: disable=invalid-name
if platform.system().lower() == "windows":
    # The 'lib' directory located next to this file; it is the directory that
    # gets added to the DLL search path below.
    dslib_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'lib')

    # On Windows, we can't rely on RPATH being set to $ORIGIN/lib/ or on
    # @loader_path/lib
    if hasattr(os, 'add_dll_directory'):
        # Starting with Python 3.8 this properly handles the problem
        os.add_dll_directory(dslib_path)
    else:
        # Before Python 3.8 we need to change the PATH to include the proper
        # directory for the dynamic linker. PATH may be missing or empty in a
        # stripped-down environment, so only append it when there is
        # something to append instead of failing with a KeyError at import.
        os.environ['PATH'] = os.pathsep.join(
            entry for entry in (dslib_path, os.environ.get('PATH')) if entry
        )
import deepspeech
# rename for backwards compatibility
from deepspeech.impl import Version as version
class Model(object):
    """
    Class holding a DeepSpeech model

    :param model_path: Path to model file to load
    :type model_path: str

    :throws: RuntimeError if the model cannot be created
    """

    def __init__(self, model_path):
        # make sure the attribute is there if CreateModel fails
        self._impl = None
        status, impl = deepspeech.impl.CreateModel(model_path)
        self._check_status("CreateModel", status)
        self._impl = impl

    def __del__(self):
        # __del__ also runs for instances whose __init__ never executed
        # (e.g. created through __new__), so do not assume the attribute
        # exists.
        impl = getattr(self, "_impl", None)
        if impl:
            deepspeech.impl.FreeModel(impl)
            self._impl = None

    @staticmethod
    def _check_status(api_name, status):
        """
        Raise a RuntimeError describing a failed native call.

        :param api_name: Name of the native API function that was called,
                         used as the prefix of the error message.
        :type api_name: str

        :param status: Status code returned by the native call. Zero means
                       success.
        :type status: int

        :throws: RuntimeError if status is non-zero
        """
        if status != 0:
            raise RuntimeError("{} failed with '{}' (0x{:X})".format(
                api_name,
                deepspeech.impl.ErrorCodeToErrorMessage(status),
                status))

    def beamWidth(self):
        """
        Get beam width value used by the model. If setModelBeamWidth was not
        called before, will return the default value loaded from the model file.

        :return: Beam width value used by the model.
        :type: int
        """
        return deepspeech.impl.GetModelBeamWidth(self._impl)

    def setBeamWidth(self, beam_width):
        """
        Set beam width value used by the model.

        :param beam_width: The beam width used by the model. A larger beam width value generates better results at the cost of decoding time.
        :type beam_width: int

        :return: Zero on success, non-zero on failure.
        :type: int
        """
        return deepspeech.impl.SetModelBeamWidth(self._impl, beam_width)

    def sampleRate(self):
        """
        Return the sample rate expected by the model.

        :return: Sample rate.
        :type: int
        """
        return deepspeech.impl.GetModelSampleRate(self._impl)

    def enableExternalScorer(self, scorer_path):
        """
        Enable decoding using an external scorer.

        :param scorer_path: The path to the external scorer file.
        :type scorer_path: str

        :throws: RuntimeError on error
        """
        status = deepspeech.impl.EnableExternalScorer(self._impl, scorer_path)
        self._check_status("EnableExternalScorer", status)

    def disableExternalScorer(self):
        """
        Disable decoding using an external scorer.

        :return: Zero on success, non-zero on failure.
        """
        return deepspeech.impl.DisableExternalScorer(self._impl)

    def addHotWord(self, word, boost):
        """
        Add a word and its boost for decoding.

        Words that don't occur in the scorer (e.g. proper nouns) or strings that contain spaces won't be taken into account.

        :param word: the hot-word
        :type word: str

        :param boost: Positive boost value increases and negative reduces chance of a word occurring in a transcription. Excessive positive boost might lead to splitting up of letters of the word following the hot-word.
        :type boost: float

        :throws: RuntimeError on error
        """
        status = deepspeech.impl.AddHotWord(self._impl, word, boost)
        self._check_status("AddHotWord", status)

    def eraseHotWord(self, word):
        """
        Remove entry for word from hot-words dict.

        :param word: the hot-word
        :type word: str

        :throws: RuntimeError on error
        """
        status = deepspeech.impl.EraseHotWord(self._impl, word)
        self._check_status("EraseHotWord", status)

    def clearHotWords(self):
        """
        Remove all entries from hot-words dict.

        :throws: RuntimeError on error
        """
        status = deepspeech.impl.ClearHotWords(self._impl)
        self._check_status("ClearHotWords", status)

    def setScorerAlphaBeta(self, alpha, beta):
        """
        Set hyperparameters alpha and beta of the external scorer.

        :param alpha: The alpha hyperparameter of the decoder. Language model weight.
        :type alpha: float

        :param beta: The beta hyperparameter of the decoder. Word insertion weight.
        :type beta: float

        :return: Zero on success, non-zero on failure.
        :type: int
        """
        return deepspeech.impl.SetScorerAlphaBeta(self._impl, alpha, beta)

    def stt(self, audio_buffer):
        """
        Use the DeepSpeech model to perform Speech-To-Text.

        :param audio_buffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
        :type audio_buffer: numpy.int16 array

        :return: The STT result.
        :type: str
        """
        return deepspeech.impl.SpeechToText(self._impl, audio_buffer)

    def sttWithMetadata(self, audio_buffer, num_results=1):
        """
        Use the DeepSpeech model to perform Speech-To-Text and return results including metadata.

        :param audio_buffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
        :type audio_buffer: numpy.int16 array

        :param num_results: Maximum number of candidate transcripts to return. Returned list might be smaller than this.
        :type num_results: int

        :return: Metadata object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information.
        :type: :func:`Metadata`
        """
        return deepspeech.impl.SpeechToTextWithMetadata(self._impl, audio_buffer, num_results)

    def createStream(self):
        """
        Create a new streaming inference state. The streaming state returned by
        this function can then be passed to :func:`feedAudioContent()` and :func:`finishStream()`.

        :return: Stream object representing the newly created stream
        :type: :func:`Stream`

        :throws: RuntimeError on error
        """
        status, ctx = deepspeech.impl.CreateStream(self._impl)
        self._check_status("CreateStream", status)
        stream = Stream(ctx)
        # Keep this Model referenced for as long as the returned stream object
        # lives, so that __del__ above cannot call FreeModel while the stream
        # is still in use (e.g. `Model(path).createStream()`).
        # NOTE(review): this assumes the native stream refers to the native
        # model state it was created from - confirm against the native client.
        stream._model = self  # pylint: disable=protected-access
        return stream
class Stream(object):
    """
    Class wrapping a DeepSpeech stream. The constructor cannot be called directly.
    Use :func:`Model.createStream()`
    """

    def __init__(self, native_stream):
        self._impl = native_stream

    def __del__(self):
        # __del__ also runs for instances whose __init__ never executed
        # (e.g. created through __new__), so do not assume the attribute
        # exists. A stream that was already finished or freed has _impl set
        # to None and needs no cleanup.
        if getattr(self, "_impl", None):
            self.freeStream()

    def _require_valid(self, action):
        """
        Ensure the wrapped native stream is still usable.

        :param action: Verb describing the attempted operation ("feed",
                       "decode", "finish" or "free"), inserted into the
                       error message.
        :type action: str

        :throws: RuntimeError if the stream object is not valid
        """
        if not self._impl:
            raise RuntimeError(
                "Stream object is not valid. Trying to {} an already finished stream?".format(action))

    def feedAudioContent(self, audio_buffer):
        """
        Feed audio samples to an ongoing streaming inference.

        :param audio_buffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
        :type audio_buffer: numpy.int16 array

        :throws: RuntimeError if the stream object is not valid
        """
        self._require_valid("feed")
        deepspeech.impl.FeedAudioContent(self._impl, audio_buffer)

    def intermediateDecode(self):
        """
        Compute the intermediate decoding of an ongoing streaming inference.

        :return: The STT intermediate result.
        :type: str

        :throws: RuntimeError if the stream object is not valid
        """
        self._require_valid("decode")
        return deepspeech.impl.IntermediateDecode(self._impl)

    def intermediateDecodeWithMetadata(self, num_results=1):
        """
        Compute the intermediate decoding of an ongoing streaming inference and return results including metadata.

        :param num_results: Maximum number of candidate transcripts to return. Returned list might be smaller than this.
        :type num_results: int

        :return: Metadata object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information.
        :type: :func:`Metadata`

        :throws: RuntimeError if the stream object is not valid
        """
        self._require_valid("decode")
        return deepspeech.impl.IntermediateDecodeWithMetadata(self._impl, num_results)

    def finishStream(self):
        """
        Compute the final decoding of an ongoing streaming inference and return
        the result. Signals the end of an ongoing streaming inference. The underlying
        stream object must not be used after this method is called.

        :return: The STT result.
        :type: str

        :throws: RuntimeError if the stream object is not valid
        """
        self._require_valid("finish")
        result = deepspeech.impl.FinishStream(self._impl)
        # Mark the stream as finished so later calls fail with RuntimeError
        # and __del__ does not try to free it.
        self._impl = None
        return result

    def finishStreamWithMetadata(self, num_results=1):
        """
        Compute the final decoding of an ongoing streaming inference and return
        results including metadata. Signals the end of an ongoing streaming
        inference. The underlying stream object must not be used after this
        method is called.

        :param num_results: Maximum number of candidate transcripts to return. Returned list might be smaller than this.
        :type num_results: int

        :return: Metadata object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information.
        :type: :func:`Metadata`

        :throws: RuntimeError if the stream object is not valid
        """
        self._require_valid("finish")
        result = deepspeech.impl.FinishStreamWithMetadata(self._impl, num_results)
        # Mark the stream as finished so later calls fail with RuntimeError
        # and __del__ does not try to free it.
        self._impl = None
        return result

    def freeStream(self):
        """
        Destroy a streaming state without decoding the computed logits. This can
        be used if you no longer need the result of an ongoing streaming inference.

        :throws: RuntimeError if the stream object is not valid
        """
        self._require_valid("free")
        deepspeech.impl.FreeStream(self._impl)
        self._impl = None
# The classes below exist for documentation purposes only.
# Metadata, CandidateTranscript and TokenMetadata should be kept in sync with native_client/deepspeech.h
class TokenMetadata(object):
    """
    Holds a single character together with its timing information.

    Documentation-only stub: the methods below carry no implementation.
    """

    def text(self):
        """
        Text of this token.
        """

    def timestep(self):
        """
        Token position, expressed in units of 20ms.
        """

    def start_time(self):
        """
        Token position, expressed in seconds.
        """
class CandidateTranscript(object):
    """
    Holds the whole CTC output as an array of character metadata objects.

    Documentation-only stub: the methods below carry no implementation.
    """

    def tokens(self):
        """
        The tokens making up this transcript.

        :return: A list of :func:`TokenMetadata` elements
        :type: list
        """

    def confidence(self):
        """
        Approximate confidence value of this transcription. Roughly, it is
        the sum of the acoustic model logit values of every
        timestep/character that contributed to creating this transcription.
        """
class Metadata(object):
    """
    Holds the candidate transcripts produced by a decoding call.

    Documentation-only stub: the method below carries no implementation.
    """

    def transcripts(self):
        """
        The candidate transcripts.

        :return: A list of :func:`CandidateTranscript` objects
        :type: list
        """