config.py (forked from mozilla/DeepSpeech)
from __future__ import absolute_import, division, print_function

import os

import tensorflow as tf

from attrdict import AttrDict
from xdg import BaseDirectory as xdg

from util.flags import FLAGS
from util.gpu import get_available_gpus
from util.logging import log_error
from util.text import Alphabet


class ConfigSingleton:
    _config = None

    def __getattr__(self, name):
        if not ConfigSingleton._config:
            raise RuntimeError("Global configuration not yet initialized.")
        if not hasattr(ConfigSingleton._config, name):
            raise RuntimeError("Configuration option {} not found in config.".format(name))
        return ConfigSingleton._config[name]


Config = ConfigSingleton() # pylint: disable=invalid-name
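
# A minimal usage sketch (an illustration, assuming this module is importable as
# util.config as in upstream DeepSpeech): call initialize_globals() once after
# the command-line flags have been parsed, then read options as attributes:
#
#     from util.config import Config, initialize_globals
#     initialize_globals()
#     print(Config.n_hidden, Config.alphabet.size())
#
# Attribute access on Config is forwarded by __getattr__ above to the AttrDict
# stored in ConfigSingleton._config.
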

def initialize_globals():
    c = AttrDict()

    # CPU device
    c.cpu_device = '/cpu:0'

    # Available GPU devices
    c.available_devices = get_available_gpus()

    # If there is no GPU available, we fall back to CPU based operation
    if not c.available_devices:
        c.available_devices = [c.cpu_device]

    # Set default dropout rates
    if FLAGS.dropout_rate2 < 0:
        FLAGS.dropout_rate2 = FLAGS.dropout_rate
    if FLAGS.dropout_rate3 < 0:
        FLAGS.dropout_rate3 = FLAGS.dropout_rate
    if FLAGS.dropout_rate6 < 0:
        FLAGS.dropout_rate6 = FLAGS.dropout_rate
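
    # Illustrative note (not from the upstream docs): a rate left below zero means
    # "inherit from --dropout_rate", so e.g. --dropout_rate 0.15 with the
    # layer-specific rates left negative applies 0.15 to layers 2, 3 and 6 as well.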

    # Set default checkpoint dir
    if not FLAGS.checkpoint_dir:
        FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('deepspeech', 'checkpoints'))

    if FLAGS.load not in ['last', 'best', 'init', 'auto']:
        FLAGS.load = 'auto'

    # Set default summary dir
    if not FLAGS.summary_dir:
        FLAGS.summary_dir = xdg.save_data_path(os.path.join('deepspeech', 'summaries'))
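
    # Background note on pyxdg (not specific to this file): save_data_path()
    # creates and returns a directory under $XDG_DATA_HOME, so on a typical Linux
    # setup these defaults resolve to ~/.local/share/deepspeech/checkpoints and
    # ~/.local/share/deepspeech/summaries.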

    # Standard session configuration that'll be used for all new sessions.
    c.session_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement,
                                      inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
                                      intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads)
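
    # Background note: allow_soft_placement=True lets TensorFlow fall back to a
    # CPU kernel when an op has no implementation for the requested device, so
    # the same graph-building code keeps working on GPU-less machines.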

    c.alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path))

    # Geometric Constants
    # ===================

    # For an explanation of the meaning of the geometric constants, please refer to
    # doc/Geometry.md

    # Number of MFCC features
    c.n_input = 26 # TODO: Determine this programmatically from the sample rate

    # The number of frames in the context
    c.n_context = 9 # TODO: Determine the optimal value using a validation data set
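
    # Worked example of the resulting geometry (see doc/Geometry.md for the full
    # derivation): each time step is fed the current frame plus n_context frames
    # on either side, i.e. n_input * (2 * n_context + 1) = 26 * 19 = 494 input
    # features per step.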

    # Number of units in hidden layers
    c.n_hidden = FLAGS.n_hidden

    c.n_hidden_1 = c.n_hidden

    c.n_hidden_2 = c.n_hidden

    c.n_hidden_5 = c.n_hidden

    # LSTM cell state dimension
    c.n_cell_dim = c.n_hidden

    # The number of units in the third layer, which feeds in to the LSTM
    c.n_hidden_3 = c.n_cell_dim

    # Units in the sixth layer = number of characters in the target language plus one
    c.n_hidden_6 = c.alphabet.size() + 1 # +1 for CTC blank label
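
    # For example, with the stock English alphabet of 28 symbols (a-z, space and
    # apostrophe, as shipped in data/alphabet.txt) this yields 29 output units,
    # the extra one being the CTC blank; other alphabets will differ.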

    # Size of audio window in samples
    c.audio_window_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_len / 1000)

    # Stride for feature computations in samples
    c.audio_step_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_step / 1000)
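
    # Numeric sketch (the flag values here are illustrative, not guaranteed
    # defaults): at a 16 kHz sample rate with a 32 ms window and a 20 ms step,
    #     audio_window_samples = 16000 * (32 / 1000) = 512.0
    #     audio_step_samples   = 16000 * (20 / 1000) = 320.0
    # Both come out as floats because of the true division imported above.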

    if FLAGS.one_shot_infer:
        if not os.path.exists(FLAGS.one_shot_infer):
            log_error('Path specified in --one_shot_infer is not a valid file.')
            exit(1)

    ConfigSingleton._config = c # pylint: disable=protected-access