Skip to content

Commit

Permalink
Update several things
Browse files Browse the repository at this point in the history
  • Loading branch information
sooftware committed Jan 12, 2021
1 parent d625b55 commit 5211e80
Show file tree
Hide file tree
Showing 24 changed files with 81 additions and 309 deletions.
6 changes: 2 additions & 4 deletions bin/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@
# limitations under the License.

import os
import sys
import hydra
import warnings
sys.path.append('..')

from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf, DictConfig
from kospeech.dataclass import EvalConfig, FilterBankConfig
from kospeech.evaluator import EvalConfig
from kospeech.data.audio import FilterBankConfig
from kospeech.vocabs.ksponspeech import KsponSpeechVocabulary
from kospeech.vocabs.librispeech import LibriSpeechVocabulary
from kospeech.data.label_loader import load_dataset
Expand Down
28 changes: 16 additions & 12 deletions bin/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,33 @@
import argparse
import torch
import torch.nn as nn
import sys
import numpy as np
import torchaudio
sys.path.append('..')
from torch import Tensor
from kospeech.models.deepspeech2.model import DeepSpeech2
from kospeech.models.las.model import ListenAttendSpell
from kospeech.vocabs.ksponspeech import KsponSpeechVocabulary
from kospeech.data.audio.core import load_audio
from kospeech.models import (
SpeechTransformer,
Jasper,
DeepSpeech2,
ListenAttendSpell,
)


def parse_audio(audio_path: str, del_silence: bool = False, audio_extension: str = 'pcm') -> Tensor:
signal = load_audio(audio_path, del_silence, extension=audio_extension)
feature_vector = torchaudio.compliance.kaldi.fbank(
feature = torchaudio.compliance.kaldi.fbank(
waveform=Tensor(signal).unsqueeze(0),
num_mel_bins=80,
frame_length=20,
frame_shift=10,
window_type='hamming'
).transpose(0, 1).numpy()

feature_vector -= feature_vector.mean()
feature_vector /= np.std(feature_vector)
feature -= feature.mean()
feature /= np.std(feature)

return torch.FloatTensor(feature_vector).transpose(0, 1)
return torch.FloatTensor(feature).transpose(0, 1)


parser = argparse.ArgumentParser(description='KoSpeech')
Expand All @@ -48,8 +50,8 @@ def parse_audio(audio_path: str, del_silence: bool = False, audio_extension: str
parser.add_argument('--device', type=str, require=False, default='cpu')
opt = parser.parse_args()

feature_vector = parse_audio(opt.audio_path, del_silence=True)
input_length = torch.IntTensor([len(feature_vector)])
feature = parse_audio(opt.audio_path, del_silence=True)
input_length = torch.LongTensor([len(feature)])
vocab = KsponSpeechVocabulary('data/vocab/aihub_character_vocabs.csv')

model = torch.load(opt.model_path, map_location=lambda storage, loc: storage).to(opt.device)
Expand All @@ -61,10 +63,12 @@ def parse_audio(audio_path: str, del_silence: bool = False, audio_extension: str
model.encoder.device = opt.device
model.decoder.device = opt.device

y_hats = model.greedy_search(feature_vector.unsqueeze(0), input_length, opt.device)
y_hats = model.greedy_search(feature.unsqueeze(0), input_length, opt.device)
elif isinstance(model, DeepSpeech2):
model.device = opt.device
y_hats = model.greedy_search(feature_vector.unsqueeze(0), input_length, opt.device)
y_hats = model.greedy_search(feature.unsqueeze(0), input_length, opt.device)
elif isinstance(model, SpeechTransformer) or isinstance(model, Jasper):
y_hats = model.greedy_search(feature.unsqueeze(0), input_length, opt.device)

sentence = vocab.label_to_string(y_hats.cpu().detach().numpy())
print(sentence)
24 changes: 15 additions & 9 deletions bin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import random
import warnings
import torch
import torch.nn as nn
import hydra

from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf, DictConfig
from kospeech.data.data_loader import split_dataset
Expand Down Expand Up @@ -57,7 +57,13 @@
)


def train(config: DictConfig):
KSPONSPEECH_VOCAB_PATH = '../../../data/vocab/kspon_sentencepiece.vocab'
KSPONSPEECH_SP_MODEL_PATH = '../../../data/vocab/kspon_sentencepiece.model'
LIBRISPEECH_VOCAB_PATH = '../../../data/vocab/tokenizer.vocab'
LIBRISPEECH_TOKENIZER_PATH = '../../../data/vocab/tokenizer.model'


def train(config: DictConfig) -> nn.DataParallel:
random.seed(config.train.seed)
torch.manual_seed(config.train.seed)
torch.cuda.manual_seed_all(config.train.seed)
Expand All @@ -66,19 +72,18 @@ def train(config: DictConfig):
if config.train.dataset == 'kspon':
if config.train.output_unit == 'subword':
vocab = KsponSpeechVocabulary(
vocab_path='../../../data/vocab/kspon_sentencepiece.vocab',
vocab_path=KSPONSPEECH_VOCAB_PATH,
output_unit=config.train.output_unit,
sp_model_path='../../../data/vocab/kspon_sentencepiece.model',
sp_model_path=KSPONSPEECH_SP_MODEL_PATH,
)
else:
vocab = KsponSpeechVocabulary(
f'../../../data/vocab/aihub_{config.train.output_unit}_vocabs.csv', output_unit=config.train.output_unit
f'../../../data/vocab/aihub_{config.train.output_unit}_vocabs.csv',
output_unit=config.train.output_unit,
)

elif config.train.dataset == 'libri':
vocab = LibriSpeechVocabulary(
'../../../data/vocab/tokenizer.vocab', '../../../data/vocab/tokenizer.model'
)
vocab = LibriSpeechVocabulary(LIBRISPEECH_VOCAB_PATH, LIBRISPEECH_TOKENIZER_PATH)

else:
raise ValueError("Unsupported Dataset : {0}".format(config.train.dataset))
Expand Down Expand Up @@ -158,7 +163,8 @@ def train(config: DictConfig):
def main(config: DictConfig) -> None:
warnings.filterwarnings('ignore')
logger.info(OmegaConf.to_yaml(config))
train(config)
last_model_checkpoint = train(config)
torch.save(last_model_checkpoint, os.path.join(os.getcwd(), "last_model_checkpoint.pt"))


if __name__ == '__main__':
Expand Down
14 changes: 0 additions & 14 deletions configs/audio/fbank.yaml

This file was deleted.

14 changes: 0 additions & 14 deletions configs/audio/melspectrogram.yaml

This file was deleted.

14 changes: 0 additions & 14 deletions configs/audio/mfcc.yaml

This file was deleted.

14 changes: 0 additions & 14 deletions configs/audio/spectrogram.yaml

This file was deleted.

3 changes: 0 additions & 3 deletions configs/eval.yaml

This file was deleted.

11 changes: 0 additions & 11 deletions configs/eval/default.yaml

This file was deleted.

12 changes: 0 additions & 12 deletions configs/model/ds2.yaml

This file was deleted.

20 changes: 0 additions & 20 deletions configs/model/joint-ctc-attention-las.yaml

This file was deleted.

15 changes: 0 additions & 15 deletions configs/model/joint-ctc-attention-transformer.yaml

This file was deleted.

18 changes: 0 additions & 18 deletions configs/model/las.yaml

This file was deleted.

12 changes: 0 additions & 12 deletions configs/model/transformer.yaml

This file was deleted.

4 changes: 0 additions & 4 deletions configs/train.yaml

This file was deleted.

31 changes: 0 additions & 31 deletions configs/train/ds2_train.yaml

This file was deleted.

31 changes: 0 additions & 31 deletions configs/train/las_train.yaml

This file was deleted.

Loading

0 comments on commit 5211e80

Please sign in to comment.