Skip to content

Commit

Permalink
Merge pull request sooftware#139 from hwiorn/fix-cpu-training-for-con…
Browse files Browse the repository at this point in the history
…former

Fix cpu training for conformer
  • Loading branch information
sooftware authored May 19, 2021
2 parents c535c2a + 06415c7 commit dcf9024
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 1 deletion.
2 changes: 2 additions & 0 deletions bin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def train(config: DictConfig) -> nn.DataParallel:
torch.manual_seed(config.train.seed)
torch.cuda.manual_seed_all(config.train.seed)
device = check_envirionment(config.train.use_cuda)
if config.train.num_threads and int(config.train.num_threads) > 0:
torch.set_num_threads(config.train.num_threads)

vocab = KsponSpeechVocabulary(
f'../../../data/vocab/aihub_{config.train.output_unit}_vocabs.csv',
Expand Down
2 changes: 1 addition & 1 deletion kospeech/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def build_conformer(
half_step_residual=half_step_residual,
device=device,
decoder=decoder,
))
)).to(device)


def build_deepspeech2(
Expand Down
3 changes: 3 additions & 0 deletions kospeech/models/conformer/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __init__(
d_model=encoder_dim,
num_heads=num_attention_heads,
dropout_p=attention_dropout_p,
device=device,
),
),
ResidualConnectionModule(
Expand All @@ -99,13 +100,15 @@ def __init__(
kernel_size=conv_kernel_size,
expansion_factor=conv_expansion_factor,
dropout_p=conv_dropout_p,
device=device,
),
),
ResidualConnectionModule(
module=FeedForwardModule(
encoder_dim=encoder_dim,
expansion_factor=feed_forward_expansion_factor,
dropout_p=feed_forward_dropout_p,
device=device,
),
module_factor=self.feed_forward_residual_factor,
),
Expand Down
1 change: 1 addition & 0 deletions kospeech/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class TrainConfig:

num_workers: int = 4
use_cuda: bool = True
num_threads: int = 2

init_lr_scale: float = 0.01
final_lr_scale: float = 0.05
Expand Down

0 comments on commit dcf9024

Please sign in to comment.