Add Paarth's HiFi-GAN and Tacotron fine-tuning code (NVIDIA#3000)
Signed-off-by: Jocelyn Huang <[email protected]>
redoctopus authored Oct 13, 2021
1 parent 38e74de commit 9bbbe2e
Showing 8 changed files with 365 additions and 25 deletions.
2 changes: 2 additions & 0 deletions examples/tts/conf/fastpitch_align_44100.yaml
@@ -18,6 +18,8 @@ model:
   pitch_fmax: 640
   pitch_avg: 211.27540199742586 # for a female speaker (8051 HiFiGAN)
   pitch_std: 52.1851002822779 # for a female speaker (8051 HiFiGAN)
+  dur_loss_scale: 0.1
+  pitch_loss_scale: 0.1

   train_ds:
     dataset:
67 changes: 67 additions & 0 deletions examples/tts/conf/hifigan/hifigan44100.yaml
@@ -0,0 +1,67 @@
name: "HifiGan"
train_dataset: ???
validation_datasets: ???

defaults:
- model/generator: v4
- model/train_ds: train_ds
- model/validation_ds: val_ds

model:
preprocessor:
_target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures
dither: 0.0
frame_splicing: 1
nfilt: 80
highfreq: null
log: true
log_zero_guard_type: clamp
log_zero_guard_value: 1e-05
lowfreq: 0
mag_power: 1.0
n_fft: 2048
n_window_size: 2048
n_window_stride: 512
normalize: null
pad_to: 0
pad_value: -11.52
preemph: null
sample_rate: 44100
window: hann
use_grads: false
exact_pad: true

optim:
_target_: torch.optim.AdamW
lr: 0.0002
betas: [0.8, 0.99]

sched:
name: CosineAnnealing
min_lr: 1e-5
warmup_ratio: 0.02

max_steps: 25000000
l1_loss_factor: 45
denoise_strength: 0.0025

trainer:
gpus: -1 # number of gpus
max_steps: ${model.max_steps}
num_nodes: 1
accelerator: ddp
accumulate_grad_batches: 1
checkpoint_callback: False # Provided by exp_manager
logger: False # Provided by exp_manager
flush_logs_every_n_steps: 200
log_every_n_steps: 100
check_val_every_n_epoch: 10

exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: True
create_checkpoint_callback: True
checkpoint_callback_params:
monitor: "val_loss"
mode: "min"
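
To see how the interpolations in this config resolve, here is a minimal sketch (not part of the commit) that loads the file with plain OmegaConf. The manifest paths are hypothetical placeholders, and Hydra's composition of the defaults groups (generator, train_ds, validation_ds) is skipped, so only the inline sections are present.

# Minimal sketch (not part of the commit): inspect the fine-tuning config with
# plain OmegaConf. Hydra's "defaults" composition is skipped here, so only the
# inline sections (model.preprocessor, model.optim, trainer, exp_manager) exist.
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/tts/conf/hifigan/hifigan44100.yaml")
cfg.train_dataset = "manifests/train_manifest.json"       # hypothetical path
cfg.validation_datasets = "manifests/val_manifest.json"   # hypothetical path

# ${model.max_steps} in the trainer section resolves against the model section.
assert cfg.trainer.max_steps == cfg.model.max_steps == 25000000
print(OmegaConf.to_yaml(cfg.model.preprocessor))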
8 changes: 8 additions & 0 deletions examples/tts/conf/hifigan/model/generator/v4.yaml
@@ -0,0 +1,8 @@
# @package _group_
_target_: nemo.collections.tts.modules.hifigan_modules.Generator
resblock: 1
upsample_rates: [8,8,4,2]
upsample_kernel_sizes: [16,16,4,4]
upsample_initial_channel: 512
resblock_kernel_sizes: [3,7,11]
resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
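
The generator's upsample rates have to multiply out to the vocoder's mel hop length, so that each mel frame expands to exactly one hop of audio samples. A quick check (not part of the commit):

# Quick check (not part of the commit): the total upsampling factor of the v4
# generator must equal the mel hop length (n_window_stride: 512 in
# hifigan44100.yaml at 44.1 kHz).
import math

upsample_rates = [8, 8, 4, 2]
hop_length = 512
assert math.prod(upsample_rates) == hop_length  # 8 * 8 * 4 * 2 == 512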
175 changes: 175 additions & 0 deletions examples/tts/conf/tacotron2_44100.yaml
@@ -0,0 +1,175 @@
name: Tacotron2
sample_rate: 44100
# <PAD>, <BOS>, <EOS> will be added by the tacotron2.py script
labels:
- ' '
- '!'
- '"'
- ''''
- (
- )
- ','
- '-'
- .
- ':'
- ;
- '?'
- a
- b
- c
- d
- e
- f
- g
- h
- i
- j
- k
- l
- m
- 'n'
- o
- p
- q
- r
- s
- t
- u
- v
- w
- x
- 'y'
- z
n_fft: 2048
n_mels: 80
fmax: null
n_stride: 512
pad_value: -11.52
train_dataset: ???
validation_datasets: ???

model:
  labels: ${labels}
  train_ds:
    dataset:
      _target_: "nemo.collections.asr.data.audio_to_text.AudioToCharDataset"
      manifest_filepath: ${train_dataset}
      max_duration: null
      min_duration: 0.1
      trim: false
      int_values: false
      normalize: true
      sample_rate: ${sample_rate}
      # bos_id: 66
      # eos_id: 67
      # pad_id: 68 These parameters are added automatically in Tacotron2
    dataloader_params:
      drop_last: false
      shuffle: true
      batch_size: 48
      num_workers: 4

  validation_ds:
    dataset:
      _target_: "nemo.collections.asr.data.audio_to_text.AudioToCharDataset"
      manifest_filepath: ${validation_datasets}
      max_duration: null
      min_duration: 0.1
      int_values: false
      normalize: true
      sample_rate: ${sample_rate}
      trim: false
      # bos_id: 66
      # eos_id: 67
      # pad_id: 68 These parameters are added automatically in Tacotron2
    dataloader_params:
      drop_last: false
      shuffle: false
      batch_size: 48
      num_workers: 8

  preprocessor:
    _target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures
    dither: 0.0
    nfilt: ${n_mels}
    frame_splicing: 1
    highfreq: ${fmax}
    log: true
    log_zero_guard_type: clamp
    log_zero_guard_value: 1e-05
    lowfreq: 0
    mag_power: 1.0
    n_fft: ${n_fft}
    n_window_size: 2048
    n_window_stride: ${n_stride}
    normalize: null
    pad_to: 16
    pad_value: ${pad_value}
    preemph: null
    sample_rate: ${sample_rate}
    window: hann

  encoder:
    _target_: nemo.collections.tts.modules.tacotron2.Encoder
    encoder_kernel_size: 5
    encoder_n_convolutions: 3
    encoder_embedding_dim: 512

  decoder:
    _target_: nemo.collections.tts.modules.tacotron2.Decoder
    decoder_rnn_dim: 1024
    encoder_embedding_dim: ${model.encoder.encoder_embedding_dim}
    gate_threshold: 0.5
    max_decoder_steps: 1000
    n_frames_per_step: 1 # currently only 1 is supported
    n_mel_channels: ${n_mels}
    p_attention_dropout: 0.1
    p_decoder_dropout: 0.1
    prenet_dim: 256
    prenet_p_dropout: 0.5
    # Attention parameters
    attention_dim: 128
    attention_rnn_dim: 1024
    # AttentionLocation Layer parameters
    attention_location_kernel_size: 31
    attention_location_n_filters: 32
    early_stopping: true

  postnet:
    _target_: nemo.collections.tts.modules.tacotron2.Postnet
    n_mel_channels: ${n_mels}
    p_dropout: 0.5
    postnet_embedding_dim: 512
    postnet_kernel_size: 5
    postnet_n_convolutions: 5

  optim:
    name: adam
    lr: 1e-3
    weight_decay: 1e-6

    # scheduler setup
    sched:
      name: CosineAnnealing
      min_lr: 1e-5

trainer:
  gpus: 1 # number of gpus
  max_epochs: ???
  num_nodes: 1
  accelerator: ddp
  accumulate_grad_batches: 1
  checkpoint_callback: False # Provided by exp_manager
  logger: False # Provided by exp_manager
  gradient_clip_val: 1.0
  flush_logs_every_n_steps: 1000
  log_every_n_steps: 200
  check_val_every_n_epoch: 25

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: True
  create_checkpoint_callback: True
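
Two quantities implied by these settings, as a quick sketch (not part of the commit): the duration of one mel frame, and the longest utterance the decoder will synthesize before max_decoder_steps cuts it off.

# Quick sketch (not part of the commit): quantities implied by the config above.
sample_rate = 44100
n_stride = 512            # hop length in samples
max_decoder_steps = 1000  # from model.decoder

frame_ms = 1000 * n_stride / sample_rate                  # ~11.6 ms of audio per mel frame
max_audio_s = max_decoder_steps * n_stride / sample_rate  # ~11.6 s before decoding stops
print(f"{frame_ms:.1f} ms per frame, {max_audio_s:.1f} s max decoded audio")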
32 changes: 32 additions & 0 deletions examples/tts/hifigan_finetune.py
@@ -0,0 +1,32 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytorch_lightning as pl

from nemo.collections.tts.models import HifiGanModel
from nemo.core.config import hydra_runner
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf/hifigan", config_name="hifigan44100")
def main(cfg):
trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))
model = HifiGanModel(cfg=cfg.model, trainer=trainer)
model.maybe_init_from_pretrained_checkpoint(cfg=cfg)
trainer.fit(model)


if __name__ == '__main__':
main() # noqa pylint: disable=no-value-for-parameter
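
maybe_init_from_pretrained_checkpoint() only restores weights when the config carries one of NeMo's init_from_* keys. A minimal sketch of adding one before that call, assuming the standard init_from_pretrained_model key and an NGC model name ("tts_hifigan") that may differ in your setup:

# Minimal sketch (not part of the commit), assuming the standard NeMo
# init_from_pretrained_model key and a hypothetical NGC model name.
from omegaconf import open_dict

with open_dict(cfg):  # relax struct mode so a new key can be added
    cfg.init_from_pretrained_model = "tts_hifigan"  # assumed model name
model.maybe_init_from_pretrained_checkpoint(cfg=cfg)

The same effect is typically achieved from the command line with a Hydra override such as +init_from_pretrained_model=<model name>.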
45 changes: 45 additions & 0 deletions examples/tts/tacotron2_finetune.py
@@ -0,0 +1,45 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytorch_lightning as pl

from nemo.collections.common.callbacks import LogEpochTimeCallback
from nemo.collections.tts.models import Tacotron2Model
from nemo.core.config import hydra_runner
from nemo.utils.exp_manager import exp_manager


# hydra_runner is a thin NeMo wrapper around Hydra
# It looks for a config named tacotron2_44100.yaml inside the conf folder
# Hydra parses the yaml and returns it as an OmegaConf DictConfig
@hydra_runner(config_path="conf", config_name="tacotron2_44100")
def main(cfg):
    # Define the Lightning trainer
    trainer = pl.Trainer(**cfg.trainer)
    # exp_manager is a NeMo construct that helps with logging and checkpointing
    exp_manager(trainer, cfg.get("exp_manager", None))
    # Define the Tacotron 2 model; this constructs the model as well as
    # the training and validation dataloaders
    model = Tacotron2Model(cfg=cfg.model, trainer=trainer)
    model.maybe_init_from_pretrained_checkpoint(cfg=cfg)
    # Let's add a few more callbacks
    lr_logger = pl.callbacks.LearningRateMonitor()
    epoch_time_logger = LogEpochTimeCallback()
    trainer.callbacks.extend([lr_logger, epoch_time_logger])
    # Call lightning trainer's fit() to train the model
    trainer.fit(model)


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
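
Note that trainer.max_epochs is left as ??? in tacotron2_44100.yaml, so it must be supplied at launch time, typically as a Hydra override on the command line (e.g. trainer.max_epochs=200, an illustrative value). A programmatic sketch of the same thing (not part of the commit):

# Sketch (not part of the commit): fill in the mandatory epoch count before
# building the Trainer. 200 is only an illustrative value.
cfg.trainer.max_epochs = 200
trainer = pl.Trainer(**cfg.trainer)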
10 changes: 8 additions & 2 deletions nemo/collections/tts/models/fastpitch.py
@@ -81,8 +81,14 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         self.log_train_images = False
         self.mel_loss = MelLoss()
         loss_scale = 0.1 if self.learn_alignment else 1.0
-        self.pitch_loss = PitchLoss(loss_scale=loss_scale)
-        self.duration_loss = DurationLoss(loss_scale=loss_scale)
+        dur_loss_scale = loss_scale
+        pitch_loss_scale = loss_scale
+        if "dur_loss_scale" in cfg:
+            dur_loss_scale = cfg.dur_loss_scale
+        if "pitch_loss_scale" in cfg:
+            pitch_loss_scale = cfg.pitch_loss_scale
+        self.pitch_loss = PitchLoss(loss_scale=pitch_loss_scale)
+        self.duration_loss = DurationLoss(loss_scale=dur_loss_scale)
         input_fft_kwargs = {}
         if self.learn_alignment:
             self.aligner = instantiate(self._cfg.alignment_module)
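
For readability, the two membership checks in the added lines could also be written with OmegaConf's get() and a fallback; an equivalent sketch (not part of the diff):

# Equivalent sketch (not part of the diff): read the optional loss scales from
# the model config, falling back to the alignment-dependent default.
dur_loss_scale = cfg.get("dur_loss_scale", loss_scale)
pitch_loss_scale = cfg.get("pitch_loss_scale", loss_scale)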
