Skip to content

Commit

Permalink
Fix several things
Browse files Browse the repository at this point in the history
  • Loading branch information
sooftware committed Jan 5, 2021
1 parent cf65804 commit a109c5a
Show file tree
Hide file tree
Showing 12 changed files with 268 additions and 213 deletions.
26 changes: 15 additions & 11 deletions bin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@
LibriSpeechVocabulary,
)
from kospeech.dataclasses import (
FBankConfig,
FilterBankConfig,
MelSpectrogramConfig,
MFCCConfig,
MfccConfig,
SpectrogramConfig,
TrainConfig,
DeepSpeech2TrainConfig,
ListenAttendSpellTrainConfig,
TransformerTrainConfig,
DeepSpeech2Config,
JointCTCAttentionLASConfig,
ListenAttendSpellConfig,
Expand All @@ -51,9 +53,11 @@ def train(config: DictConfig):

if config.train.dataset == 'kspon':
if config.train.output_unit == 'subword':
vocab = KsponSpeechVocabulary(vocab_path='../../../data/vocab/kspon_sentencepiece.vocab',
output_unit=config.train.output_unit,
sp_model_path='../../../data/vocab/kspon_sentencepiece.model')
vocab = KsponSpeechVocabulary(
vocab_path='../../../data/vocab/kspon_sentencepiece.vocab',
output_unit=config.train.output_unit,
sp_model_path='../../../data/vocab/kspon_sentencepiece.model',
)
else:
vocab = KsponSpeechVocabulary(
f'../../../data/vocab/aihub_{config.train.output_unit}_vocabs.csv', output_unit=config.train.output_unit
Expand Down Expand Up @@ -122,13 +126,13 @@ def train(config: DictConfig):


cs = ConfigStore.instance()
cs.store(group="audio", name="fbank", node=FBankConfig, package="audio")
cs.store(group="audio", name="fbank", node=FilterBankConfig, package="audio")
cs.store(group="audio", name="melspectrogram", node=MelSpectrogramConfig, package="audio")
cs.store(group="audio", name="mfcc", node=MFCCConfig, package="audio")
cs.store(group="audio", name="mfcc", node=MfccConfig, package="audio")
cs.store(group="audio", name="spectrogram", node=SpectrogramConfig, package="audio")
cs.store(group="train", name="ds2_train", node=TrainConfig, package="train")
cs.store(group="train", name="las_train", node=TrainConfig, package="train")
cs.store(group="train", name="transformer_train", node=TrainConfig, package="train")
cs.store(group="train", name="ds2_train", node=DeepSpeech2TrainConfig, package="train")
cs.store(group="train", name="las_train", node=ListenAttendSpellTrainConfig, package="train")
cs.store(group="train", name="transformer_train", node=TransformerTrainConfig, package="train")
cs.store(group="model", name="ds2", node=DeepSpeech2Config, package="model")
cs.store(group="model", name="las", node=ListenAttendSpellConfig, package="model")
cs.store(group="model", name="transformer", node=TransformerConfig, package="model")
Expand Down
4 changes: 2 additions & 2 deletions configs/train.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
defaults:
- audio: fbank
- model: joint-ctc-attention-las
- train: las_train
- model: joint-ctc-attention-transformer
- train: transformer_train
2 changes: 0 additions & 2 deletions kospeech/checkpoint/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,11 @@ class Checkpoint(object):
epoch (int): current epoch (an epoch is a loop through the full training data)
Attributes:
CHECKPOINT_DIR_NAME (str): name of the checkpoint directory
SAVE_PATH (str): path of file to save
TRAINER_STATE_NAME (str): name of the file storing trainer states
MODEL_NAME (str): name of the file storing model
"""

CHECKPOINT_DIR_NAME = 'checkpoints'
SAVE_PATH = 'outputs'
TRAINER_STATE_NAME = 'trainer_states.pt'
MODEL_NAME = 'model.pt'
Expand Down
163 changes: 0 additions & 163 deletions kospeech/dataclasses.py

This file was deleted.

25 changes: 25 additions & 0 deletions kospeech/dataclasses/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
# Soohwan Kim, Seyoung Bae, Cheolhwang Won.
# @ArXiv : KoSpeech: Open-Source Toolkit for End-to-End Korean Speech Recognition
# This source code is licensed under the Apache 2.0 License found in the
# LICENSE file in the root directory of this source tree.

from .evaluate import EvalConfig
from .audio import (
MfccConfig,
MelSpectrogramConfig,
SpectrogramConfig,
FilterBankConfig,
)
from .model import (
DeepSpeech2Config,
ListenAttendSpellConfig,
TransformerConfig,
JointCTCAttentionLASConfig,
JointCTCAttentionTransformerConfig,
)
from .train import (
DeepSpeech2TrainConfig,
ListenAttendSpellTrainConfig,
TransformerTrainConfig,
)
50 changes: 50 additions & 0 deletions kospeech/dataclasses/audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
# Soohwan Kim, Seyoung Bae, Cheolhwang Won.
# @ArXiv : KoSpeech: Open-Source Toolkit for End-to-End Korean Speech Recognition
# This source code is licensed under the Apache 2.0 License found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass


@dataclass
class AudioConfig:
    """Base audio options shared by the transform-specific configs below.

    Every field has a default, so ``AudioConfig()`` can be constructed with
    no arguments. The subclasses in this module extend it with
    ``transform_method``, ``n_mels`` and ``freq_mask_para``.
    """

    # --- input description ---
    audio_extension: str = "pcm"
    sample_rate: int = 16000  # presumably Hz -- TODO confirm with the consumer

    # --- framing (units are not visible here; likely milliseconds -- confirm) ---
    frame_length: int = 20
    frame_shift: int = 10

    # --- preprocessing switches ---
    normalize: bool = True
    del_silence: bool = True
    feature_extract_by: str = "kaldi"

    # --- augmentation (names suggest SpecAugment masking -- verify in the loader) ---
    time_mask_num: int = 4
    freq_mask_num: int = 2
    spec_augment: bool = True

    input_reverse: bool = False


@dataclass
class FilterBankConfig(AudioConfig):
    """Audio config preset whose ``transform_method`` is ``"fbank"``.

    Adds three defaulted fields on top of the ones inherited from
    ``AudioConfig``.
    """

    transform_method: str = "fbank"
    n_mels: int = 80
    freq_mask_para: int = 18  # presumably the frequency-mask width -- confirm


@dataclass
class MelSpectrogramConfig(AudioConfig):
    """Audio config preset whose ``transform_method`` is ``"mel"``.

    Adds three defaulted fields on top of the ones inherited from
    ``AudioConfig``.
    """

    transform_method: str = "mel"
    n_mels: int = 80
    freq_mask_para: int = 18  # presumably the frequency-mask width -- confirm


@dataclass
class MfccConfig(AudioConfig):
    """Audio config preset whose ``transform_method`` is ``"mfcc"``.

    Adds three defaulted fields on top of the ones inherited from
    ``AudioConfig``.
    """

    transform_method: str = "mfcc"
    n_mels: int = 40
    freq_mask_para: int = 8  # presumably the frequency-mask width -- confirm


@dataclass
class SpectrogramConfig(AudioConfig):
    """Audio config preset whose ``transform_method`` is ``"spectrogram"``.

    Adds three defaulted fields on top of the ones inherited from
    ``AudioConfig``.
    """

    transform_method: str = "spectrogram"
    n_mels: int = 161  # Not used
    freq_mask_para: int = 24  # presumably the frequency-mask width -- confirm
22 changes: 22 additions & 0 deletions kospeech/dataclasses/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
# Soohwan Kim, Seyoung Bae, Cheolhwang Won.
# @ArXiv : KoSpeech: Open-Source Toolkit for End-to-End Korean Speech Recognition
# This source code is licensed under the Apache 2.0 License found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass


@dataclass
class EvalConfig:
    """Defaulted option set for evaluation runs.

    All fields have defaults, so ``EvalConfig()`` needs no arguments;
    ``dataset_path`` and ``model_path`` default to empty strings.
    """

    # --- data locations ---
    dataset: str = 'kspon'
    dataset_path: str = ''
    transcript_path: str = '../../../data/eval_transcript.txt'
    model_path: str = ''

    output_unit: str = 'character'

    # --- loader / logging knobs ---
    batch_size: int = 32
    num_workers: int = 4
    print_every: int = 20

    # --- decoding ---
    decode: str = 'greedy'
    k: int = 3  # presumably the beam width for non-greedy decoding -- TODO confirm

    use_cuda: bool = True
Loading

0 comments on commit a109c5a

Please sign in to comment.