From a3d5801c448b73e805a5b4067ba588b529f53c46 Mon Sep 17 00:00:00 2001 From: manmay nakhashi Date: Tue, 16 May 2023 04:28:21 +0530 Subject: [PATCH] Tortoise TTS inference (#2547) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * initial commit * Tortoise inference * revert path change * style fix * remove accidental remove * style fixes * style fixes * removed unwanted assests and deps * remove changes * remove cvvp * style fix black * added tortoise config and updated config and args, refactoring the code * added tortoise to api * Pull mel_norm from url * Use TTS cleaners * Let download model files * add ability to pass tortoise presets through coqui api * fix tests * fix style and tests * fix tts commandline for tortoise * Add config.json to tortoise * Use kwargs * Use regular model api for loading tortoise * Add load from dir to synthesizer * Fix Tortoise floats * Use model_dir when there are multiple urls * Use `synthesize` when exists * lint fixes and resolve preset bug * resolve a download bug and update model link * fix json * do tortoise inference from voice dir * fix * fix test * fix speaker id and remove assests * update inference_tests.yml * replace inference_test.yml * fix extra dir as None * fix tests * remove space * Reformat docstring * Add docs * Update docs * lint fixes --------- Co-authored-by: Eren Gölge Co-authored-by: Eren Gölge --- .github/workflows/inference_tests.yml | 2 +- TTS/.models.json | 20 + TTS/api.py | 30 +- TTS/bin/synthesize.py | 20 +- TTS/tts/configs/tortoise_config.py | 87 + TTS/tts/layers/tortoise/arch_utils.py | 433 +++++ TTS/tts/layers/tortoise/audio_utils.py | 177 ++ TTS/tts/layers/tortoise/autoregressive.py | 631 +++++++ TTS/tts/layers/tortoise/classifier.py | 144 ++ TTS/tts/layers/tortoise/clvp.py | 159 ++ TTS/tts/layers/tortoise/diffusion.py | 1259 +++++++++++++ TTS/tts/layers/tortoise/diffusion_decoder.py | 415 +++++ TTS/tts/layers/tortoise/dpm_solver.py | 1562 +++++++++++++++++ 
.../tortoise/random_latent_generator.py | 55 + TTS/tts/layers/tortoise/tokenizer.py | 34 + TTS/tts/layers/tortoise/transformer.py | 229 +++ TTS/tts/layers/tortoise/utils.py | 46 + TTS/tts/layers/tortoise/vocoder.py | 401 +++++ TTS/tts/layers/tortoise/wav2vec_alignment.py | 150 ++ TTS/tts/layers/tortoise/xtransformers.py | 1259 +++++++++++++ TTS/tts/models/__init__.py | 2 +- TTS/tts/models/tortoise.py | 900 ++++++++++ TTS/tts/utils/assets/tortoise/tokenizer.json | 1 + TTS/utils/audio/torch_transforms.py | 4 +- TTS/utils/manage.py | 31 +- TTS/utils/synthesizer.py | 68 +- docs/source/index.md | 1 + docs/source/models/tortoise.md | 94 + notebooks/Tortoise.ipynb | 108 ++ requirements.txt | 5 + tests/inference_tests/test_python_api.py | 6 +- 31 files changed, 8298 insertions(+), 35 deletions(-) create mode 100644 TTS/tts/configs/tortoise_config.py create mode 100644 TTS/tts/layers/tortoise/arch_utils.py create mode 100644 TTS/tts/layers/tortoise/audio_utils.py create mode 100644 TTS/tts/layers/tortoise/autoregressive.py create mode 100644 TTS/tts/layers/tortoise/classifier.py create mode 100644 TTS/tts/layers/tortoise/clvp.py create mode 100644 TTS/tts/layers/tortoise/diffusion.py create mode 100644 TTS/tts/layers/tortoise/diffusion_decoder.py create mode 100644 TTS/tts/layers/tortoise/dpm_solver.py create mode 100644 TTS/tts/layers/tortoise/random_latent_generator.py create mode 100644 TTS/tts/layers/tortoise/tokenizer.py create mode 100644 TTS/tts/layers/tortoise/transformer.py create mode 100644 TTS/tts/layers/tortoise/utils.py create mode 100644 TTS/tts/layers/tortoise/vocoder.py create mode 100644 TTS/tts/layers/tortoise/wav2vec_alignment.py create mode 100644 TTS/tts/layers/tortoise/xtransformers.py create mode 100644 TTS/tts/models/tortoise.py create mode 100644 TTS/tts/utils/assets/tortoise/tokenizer.json create mode 100644 docs/source/models/tortoise.md create mode 100644 notebooks/Tortoise.ipynb diff --git a/.github/workflows/inference_tests.yml 
b/.github/workflows/inference_tests.yml index 828f78b47a..4441d83cf6 100644 --- a/.github/workflows/inference_tests.yml +++ b/.github/workflows/inference_tests.yml @@ -52,4 +52,4 @@ jobs: - name: Unit tests run: make inference_tests env: - COQUI_STUDIO_TOKEN: ${{ secrets.COQUI_STUDIO_TOKEN }} + COQUI_STUDIO_TOKEN: ${{ secrets.COQUI_STUDIO_TOKEN }} \ No newline at end of file diff --git a/TTS/.models.json b/TTS/.models.json index 61dbd7a336..b396e641c7 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -220,6 +220,26 @@ "license": "apache 2.0", "contact": "adamfroghyar@gmail.com" } + + }, + "multi-dataset":{ + "tortoise-v2":{ + "description": "Tortoise tts model https://github.com/neonbjb/tortoise-tts", + "github_rls_url": ["https://coqui.gateway.scarf.sh/v0.14.1_models/autoregressive.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/clvp2.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/cvvp.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/diffusion_decoder.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/rlg_auto.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/rlg_diffuser.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/vocoder.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/mel_norms.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/config.json" + ], + "commit": "c1875f6", + "default_vocoder": null, + "author": "@neonbjb - James Betker, @manmay-nakhashi Manmay Nakhashi", + "license": "apache 2.0" + } }, "jenny": { "jenny":{ diff --git a/TTS/api.py b/TTS/api.py index 8124d3ab56..8bd087f652 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -342,10 +342,14 @@ def list_models(): def download_model_by_name(self, model_name: str): model_path, config_path, model_item = self.manager.download_model(model_name) + if isinstance(model_item["github_rls_url"], list): + # return model directory if there are multiple files + # we assume that the model knows how to load itself + return None, None, None, None, model_path if 
model_item.get("default_vocoder") is None: - return model_path, config_path, None, None + return model_path, config_path, None, None, None vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"]) - return model_path, config_path, vocoder_path, vocoder_config_path + return model_path, config_path, vocoder_path, vocoder_config_path, None def load_vc_model_by_name(self, model_name: str, gpu: bool = False): """Load one of the voice conversion models by name. @@ -355,7 +359,7 @@ def load_vc_model_by_name(self, model_name: str, gpu: bool = False): gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. """ self.model_name = model_name - model_path, config_path, _, _ = self.download_model_by_name(model_name) + model_path, config_path, _, _, _ = self.download_model_by_name(model_name) self.voice_converter = Synthesizer(vc_checkpoint=model_path, vc_config=config_path, use_cuda=gpu) def load_tts_model_by_name(self, model_name: str, gpu: bool = False): @@ -374,7 +378,9 @@ def load_tts_model_by_name(self, model_name: str, gpu: bool = False): if "coqui_studio" in model_name: self.csapi = CS_API() else: - model_path, config_path, vocoder_path, vocoder_config_path = self.download_model_by_name(model_name) + model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name( + model_name + ) # init synthesizer # None values are fetch from the model @@ -387,6 +393,7 @@ def load_tts_model_by_name(self, model_name: str, gpu: bool = False): vocoder_config=vocoder_config_path, encoder_checkpoint=None, encoder_config=None, + model_dir=model_dir, use_cuda=gpu, ) @@ -422,6 +429,7 @@ def _check_arguments( speaker_wav: str = None, emotion: str = None, speed: float = None, + **kwargs, ) -> None: """Check if the arguments are valid for the model.""" if not self.is_coqui_studio: @@ -430,7 +438,7 @@ def _check_arguments( raise ValueError("Model is multi-speaker but no `speaker` 
is provided.") if self.is_multi_lingual and language is None: raise ValueError("Model is multi-lingual but no `language` is provided.") - if not self.is_multi_speaker and speaker is not None: + if not self.is_multi_speaker and speaker is not None and "voice_dir" not in kwargs: raise ValueError("Model is not multi-speaker but `speaker` is provided.") if not self.is_multi_lingual and language is not None: raise ValueError("Model is not multi-lingual but `language` is provided.") @@ -499,6 +507,7 @@ def tts( speaker_wav: str = None, emotion: str = None, speed: float = None, + **kwargs, ): """Convert text to speech. @@ -520,12 +529,13 @@ def tts( Speed factor to use for 🐸Coqui Studio models, between 0 and 2.0. If None, Studio models use 1.0. Defaults to None. """ - self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav, emotion=emotion, speed=speed) + self._check_arguments( + speaker=speaker, language=language, speaker_wav=speaker_wav, emotion=emotion, speed=speed, **kwargs + ) if self.csapi is not None: return self.tts_coqui_studio( text=text, speaker_name=speaker, language=language, emotion=emotion, speed=speed ) - wav = self.synthesizer.tts( text=text, speaker_name=speaker, @@ -535,6 +545,7 @@ def tts( style_wav=None, style_text=None, reference_speaker_name=None, + **kwargs, ) return wav @@ -547,6 +558,7 @@ def tts_to_file( emotion: str = "Neutral", speed: float = 1.0, file_path: str = "output.wav", + **kwargs, ): """Convert text to speech. @@ -569,13 +581,13 @@ def tts_to_file( file_path (str, optional): Output file path. Defaults to "output.wav". 
""" - self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav) + self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs) if self.csapi is not None: return self.tts_coqui_studio( text=text, speaker_name=speaker, language=language, emotion=emotion, speed=speed, file_path=file_path ) - wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav) + wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs) self.synthesizer.save_wav(wav=wav, path=file_path) return file_path diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index 092264f40e..8a7e178d58 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -274,6 +274,13 @@ def main(): help="Target audio file to convert in the voice of the source_wav", ) + parser.add_argument( + "--voice_dir", + type=str, + default=None, + help="Voice dir for tortoise model", + ) + args = parser.parse_args() # print the description if either text or list_models is not set @@ -306,6 +313,7 @@ def main(): encoder_config_path = None vc_path = None vc_config_path = None + model_dir = None # CASE1 #list : list pre-trained TTS models if args.list_models: @@ -335,7 +343,6 @@ def main(): # CASE4: load pre-trained model paths if args.model_name is not None and not args.model_path: model_path, config_path, model_item = manager.download_model(args.model_name) - # tts model if model_item["model_type"] == "tts_models": tts_path = model_path @@ -348,6 +355,13 @@ def main(): vc_path = model_path vc_config_path = config_path + # tts model with multiple files to be loaded from the directory path + if isinstance(model_item["github_rls_url"], list): + model_dir = model_path + tts_path = None + tts_config_path = None + args.vocoder_name = None + # load vocoder if args.vocoder_name is not None and not args.vocoder_path: vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name) @@ -379,6 
+393,8 @@ def main(): encoder_config_path, vc_path, vc_config_path, + model_dir, + args.voice_dir, args.use_cuda, ) @@ -427,6 +443,8 @@ def main(): source_wav=args.source_wav, target_wav=args.target_wav, ) + elif model_dir is not None: + wav = synthesizer.tts(args.text, speaker_name=args.speaker_idx) # save the results print(" > Saving output to {}".format(args.out_path)) diff --git a/TTS/tts/configs/tortoise_config.py b/TTS/tts/configs/tortoise_config.py new file mode 100644 index 0000000000..7da94a4c88 --- /dev/null +++ b/TTS/tts/configs/tortoise_config.py @@ -0,0 +1,87 @@ +from dataclasses import dataclass, field + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.tortoise import TortoiseArgs, TortoiseAudioConfig + + +@dataclass +class TortoiseConfig(BaseTTSConfig): + """Defines parameters for Tortoise TTS model. + + Args: + model (str): + Model name. Do not change unless you know what you are doing. + + model_args (TortoiseArgs): + Model architecture arguments. Defaults to `TortoiseArgs()`. + + audio (TortoiseAudioConfig): + Audio processing configuration. Defaults to `TortoiseAudioConfig()`. + + model_dir (str): + Path to the folder that has all the Tortoise models. Defaults to None. + + temperature (float): + Temperature for the autoregressive model inference. Larger values make predictions more creative sacrificing stability. Defaults to `0.2`. + + length_penalty (float): + Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length, + which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative), + length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences. Defaults to `1.0`. + + repetition_penalty (float): + The parameter for repetition penalty. 1.0 means no penalty. Defaults to `2.0`. 
+ + top_p (float): + If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. + Defaults to `0.8`. + + cond_free_k (float): + Knob that determines how to balance the conditioning-free signal with the conditioning-present signal. [0,inf]. + As cond_free_k increases, the output becomes dominated by the conditioning-free signal. + Formula is: output=cond_present_output*(cond_free_k+1)-cond_absent_output*cond_free_k. Defaults to `2.0`. + + diffusion_temperature (float): + Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0 + are the "mean" prediction of the diffusion network and will sound bland and smeared. + Defaults to `1.0`. + + num_autoregressive_samples (int): + Number of samples taken from the autoregressive model, all of which are filtered using CLVP. + As Tortoise is a probabilistic model, more samples mean a higher probability of creating something "great". + Defaults to `16`. + + diffusion_iterations (int): + Number of diffusion steps to perform. [0,4000]. More steps mean the network has more chances to iteratively refine + the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better, + however. Defaults to `30`. + + sampler (str): + Diffusion sampler to be used. `ddim` or `dpm++2m`. Defaults to `ddim`. + Note: + Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters. 
+ + Example: + + >>> from TTS.tts.configs.tortoise_config import TortoiseConfig + >>> config = TortoiseConfig() + """ + + model: str = "tortoise" + # model specific params + model_args: TortoiseArgs = field(default_factory=TortoiseArgs) + audio: TortoiseAudioConfig = TortoiseAudioConfig() + model_dir: str = None + + # settings + temperature: float = 0.2 + length_penalty: float = 1.0 + repetition_penalty: float = 2.0 + top_p: float = 0.8 + cond_free_k: float = 2.0 + diffusion_temperature: float = 1.0 + + # inference params + num_autoregressive_samples: int = 16 + diffusion_iterations: int = 30 + sampler: str = "ddim" diff --git a/TTS/tts/layers/tortoise/arch_utils.py b/TTS/tts/layers/tortoise/arch_utils.py new file mode 100644 index 0000000000..dad1814369 --- /dev/null +++ b/TTS/tts/layers/tortoise/arch_utils.py @@ -0,0 +1,433 @@ +import functools +import math +import os + +import fsspec +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +from transformers import LogitsWarper + +from TTS.tts.layers.tortoise.xtransformers import ContinuousTransformerWrapper, RelativePositionBias + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def normalization(channels): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + groups = 32 + if channels <= 16: + groups = 8 + elif channels <= 64: + groups = 16 + while channels % groups != 0: + groups = int(groups / 2) + assert groups > 2 + return GroupNorm32(groups, channels) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. 
Matches legacy QKVAttention + input/output heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv, mask=None, rel_pos=None): + """ + Apply QKV attention. + + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards + if rel_pos is not None: + weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape( + bs * self.n_heads, weight.shape[-2], weight.shape[-1] + ) + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + if mask is not None: + # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs. + mask = mask.repeat(self.n_heads, 1).unsqueeze(1) + weight = weight * mask + a = torch.einsum("bts,bcs->bct", weight, v) + + return a.reshape(bs, -1, length) + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + do_checkpoint=True, + relative_pos_embeddings=False, + ): + super().__init__() + self.channels = channels + self.do_checkpoint = do_checkpoint + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.norm = normalization(channels) + self.qkv = nn.Conv1d(channels, channels * 3, 1) + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(nn.Conv1d(channels, channels, 1)) + if relative_pos_embeddings: + self.relative_pos_embeddings = RelativePositionBias( + scale=(channels // self.num_heads) ** 0.5, + causal=False, + heads=num_heads, + num_buckets=32, + max_distance=64, + ) + else: + self.relative_pos_embeddings = None + + def forward(self, x, mask=None): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv, mask, self.relative_pos_embeddings) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. 
+ """ + + def __init__(self, channels, use_conv, out_channels=None, factor=4): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.factor = factor + if use_conv: + ksize = 5 + pad = 2 + self.conv = nn.Conv1d(self.channels, self.out_channels, ksize, padding=pad) + + def forward(self, x): + assert x.shape[1] == self.channels + x = F.interpolate(x, scale_factor=self.factor, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + """ + + def __init__(self, channels, use_conv, out_channels=None, factor=4, ksize=5, pad=2): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + + stride = factor + if use_conv: + self.op = nn.Conv1d(self.channels, self.out_channels, ksize, stride=stride, padding=pad) + else: + assert self.channels == self.out_channels + self.op = nn.AvgPool1d(kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(nn.Module): + def __init__( + self, + channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + up=False, + down=False, + kernel_size=3, + ): + super().__init__() + self.channels = channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + padding = 1 if kernel_size == 3 else 2 + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False) + self.x_upd = Upsample(channels, False) + elif down: + 
self.h_upd = Downsample(channels, False) + self.x_upd = Downsample(channels, False) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module(nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding) + else: + self.skip_connection = nn.Conv1d(channels, self.out_channels, 1) + + def forward(self, x): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AudioMiniEncoder(nn.Module): + def __init__( + self, + spec_dim, + embedding_dim, + base_channels=128, + depth=2, + resnet_blocks=2, + attn_blocks=4, + num_attn_heads=4, + dropout=0, + downsample_factor=2, + kernel_size=3, + ): + super().__init__() + self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1)) + ch = base_channels + res = [] + for l in range(depth): + for r in range(resnet_blocks): + res.append(ResBlock(ch, dropout, kernel_size=kernel_size)) + res.append(Downsample(ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor)) + ch *= 2 + self.res = nn.Sequential(*res) + self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1)) + attn = [] + for a in range(attn_blocks): + attn.append( + AttentionBlock( + embedding_dim, + num_attn_heads, + ) + ) + self.attn = nn.Sequential(*attn) + self.dim = embedding_dim + + def forward(self, x): + h = self.init(x) + h = self.res(h) + h = self.final(h) + h = self.attn(h) + return h[:, :, 0] + + +DEFAULT_MEL_NORM_FILE = "https://coqui.gateway.scarf.sh/v0.14.1_models/mel_norms.pth" + + +class 
TorchMelSpectrogram(nn.Module): + def __init__( + self, + filter_length=1024, + hop_length=256, + win_length=1024, + n_mel_channels=80, + mel_fmin=0, + mel_fmax=8000, + sampling_rate=22050, + normalize=False, + mel_norm_file=DEFAULT_MEL_NORM_FILE, + ): + super().__init__() + # These are the default tacotron values for the MEL spectrogram. + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.n_mel_channels = n_mel_channels + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.sampling_rate = sampling_rate + self.mel_stft = torchaudio.transforms.MelSpectrogram( + n_fft=self.filter_length, + hop_length=self.hop_length, + win_length=self.win_length, + power=2, + normalized=normalize, + sample_rate=self.sampling_rate, + f_min=self.mel_fmin, + f_max=self.mel_fmax, + n_mels=self.n_mel_channels, + norm="slaney", + ) + self.mel_norm_file = mel_norm_file + if self.mel_norm_file is not None: + with fsspec.open(self.mel_norm_file) as f: + self.mel_norms = torch.load(f) + else: + self.mel_norms = None + + def forward(self, inp): + if ( + len(inp.shape) == 3 + ): # Automatically squeeze out the channels dimension if it is present (assuming mono-audio) + inp = inp.squeeze(1) + assert len(inp.shape) == 2 + self.mel_stft = self.mel_stft.to(inp.device) + mel = self.mel_stft(inp) + # Perform dynamic range compression + mel = torch.log(torch.clamp(mel, min=1e-5)) + if self.mel_norms is not None: + self.mel_norms = self.mel_norms.to(mel.device) + mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1) + return mel + + +class CheckpointedLayer(nn.Module): + """ + Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses + checkpoint for all other args. 
+ """ + + def __init__(self, wrap): + super().__init__() + self.wrap = wrap + + def forward(self, x, *args, **kwargs): + for k, v in kwargs.items(): + assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing. + partial = functools.partial(self.wrap, **kwargs) + return partial(x, *args) + + +class CheckpointedXTransformerEncoder(nn.Module): + """ + Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid + to channels-last that XTransformer expects. + """ + + def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs): + super().__init__() + self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs) + self.needs_permute = needs_permute + self.exit_permute = exit_permute + + if not checkpoint: + return + for i in range(len(self.transformer.attn_layers.layers)): + n, b, r = self.transformer.attn_layers.layers[i] + self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r]) + + def forward(self, x, **kwargs): + if self.needs_permute: + x = x.permute(0, 2, 1) + h = self.transformer(x, **kwargs) + if self.exit_permute: + h = h.permute(0, 2, 1) + return h + + +class TypicalLogitsWarper(LogitsWarper): + def __init__( + self, + mass: float = 0.9, + filter_value: float = -float("Inf"), + min_tokens_to_keep: int = 1, + ): + self.filter_value = filter_value + self.mass = mass + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + # calculate entropy + normalized = torch.nn.functional.log_softmax(scores, dim=-1) + p = torch.exp(normalized) + ent = -(normalized * p).nansum(-1, keepdim=True) + + # shift and sort + shifted_scores = torch.abs((-normalized) - ent) + sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False) + sorted_logits = scores.gather(-1, sorted_indices) + cumulative_probs = 
sorted_logits.softmax(dim=-1).cumsum(dim=-1) + + # Remove tokens with cumulative mass above the threshold + last_ind = (cumulative_probs < self.mass).sum(dim=1) + last_ind[last_ind < 0] = 0 + sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) + if self.min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores diff --git a/TTS/tts/layers/tortoise/audio_utils.py b/TTS/tts/layers/tortoise/audio_utils.py new file mode 100644 index 0000000000..70711ed7a4 --- /dev/null +++ b/TTS/tts/layers/tortoise/audio_utils.py @@ -0,0 +1,177 @@ +import os +from glob import glob +from typing import Dict, List + +import librosa +import numpy as np +import torch +import torchaudio +from scipy.io.wavfile import read + +from TTS.utils.audio.torch_transforms import TorchSTFT + + +def load_wav_to_torch(full_path): + sampling_rate, data = read(full_path) + if data.dtype == np.int32: + norm_fix = 2**31 + elif data.dtype == np.int16: + norm_fix = 2**15 + elif data.dtype == np.float16 or data.dtype == np.float32: + norm_fix = 1.0 + else: + raise NotImplementedError(f"Provided data dtype not supported: {data.dtype}") + return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate) + + +def check_audio(audio, audiopath: str): + # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk. + # '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds. + if torch.any(audio > 2) or not torch.any(audio < 0): + print(f"Error with {audiopath}. 
Max={audio.max()} min={audio.min()}") + audio.clip_(-1, 1) + + +def read_audio_file(audiopath: str): + if audiopath[-4:] == ".wav": + audio, lsr = load_wav_to_torch(audiopath) + elif audiopath[-4:] == ".mp3": + audio, lsr = librosa.load(audiopath, sr=None) + audio = torch.FloatTensor(audio) + else: + assert False, f"Unsupported audio format provided: {audiopath[-4:]}" + + # Remove any channel data. + if len(audio.shape) > 1: + if audio.shape[0] < 5: + audio = audio[0] + else: + assert audio.shape[1] < 5 + audio = audio[:, 0] + + return audio, lsr + + +def load_required_audio(audiopath: str): + audio, lsr = read_audio_file(audiopath) + + audios = [torchaudio.functional.resample(audio, lsr, sampling_rate) for sampling_rate in (22050, 24000)] + for audio in audios: + check_audio(audio, audiopath) + + return [audio.unsqueeze(0) for audio in audios] + + +def load_audio(audiopath, sampling_rate): + audio, lsr = read_audio_file(audiopath) + + if lsr != sampling_rate: + audio = torchaudio.functional.resample(audio, lsr, sampling_rate) + check_audio(audio, audiopath) + + return audio.unsqueeze(0) + + +TACOTRON_MEL_MAX = 2.3143386840820312 +TACOTRON_MEL_MIN = -11.512925148010254 + + +def denormalize_tacotron_mel(norm_mel): + return ((norm_mel + 1) / 2) * (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) + TACOTRON_MEL_MIN + + +def normalize_tacotron_mel(mel): + return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1 + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + +def get_voices(extra_voice_dirs: List[str] = []): + dirs = extra_voice_dirs + voices: Dict[str, List[str]] = {} + for d in dirs: + subs = os.listdir(d) + for sub in subs: + subj = os.path.join(d, sub) + if os.path.isdir(subj): + voices[sub] 
= list(glob(f"{subj}/*.wav")) + list(glob(f"{subj}/*.mp3")) + list(glob(f"{subj}/*.pth")) + return voices + + +def load_voice(voice: str, extra_voice_dirs: List[str] = []): + if voice == "random": + return None, None + + voices = get_voices(extra_voice_dirs) + paths = voices[voice] + if len(paths) == 1 and paths[0].endswith(".pth"): + return None, torch.load(paths[0]) + else: + conds = [] + for cond_path in paths: + c = load_required_audio(cond_path) + conds.append(c) + return conds, None + + +def load_voices(voices: List[str], extra_voice_dirs: List[str] = []): + latents = [] + clips = [] + for voice in voices: + if voice == "random": + if len(voices) > 1: + print("Cannot combine a random voice with a non-random voice. Just using a random voice.") + return None, None + clip, latent = load_voice(voice, extra_voice_dirs) + if latent is None: + assert ( + len(latents) == 0 + ), "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this." + clips.extend(clip) + elif clip is None: + assert ( + len(clips) == 0 + ), "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this." 
def null_position_embeddings(range, dim):
    """Return an all-zero positional embedding for the given position tensor.

    Used as a drop-in replacement for GPT-2's built-in ``wpe`` so the stock
    positional embeddings contribute nothing.

    Args:
        range: Tensor whose first two dimensions give the output's leading
            dimensions; only its shape and device are read.
        dim: Size of the trailing embedding dimension.

    Returns:
        Zero tensor of shape ``(range.shape[0], range.shape[1], dim)`` on
        ``range.device``.
    """
    batch, seq_len = range.shape[0], range.shape[1]
    return torch.zeros((batch, seq_len, dim), device=range.device)
class GPT2InferenceModel(GPT2PreTrainedModel):
    """Wrapper that exposes the Tortoise GPT-2 stack through the HuggingFace generation API.

    The embedding of the prompt prefix is not derived from ``input_ids``; it is
    supplied beforehand through :meth:`store_mel_emb` and spliced in by
    :meth:`forward`. Only the ids after that prefix are embedded with
    ``embeddings`` / ``text_pos_emb``.
    """

    def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache):
        super().__init__(config)
        self.transformer = gpt
        self.text_pos_embedding = text_pos_emb
        self.embeddings = embeddings
        self.lm_head = nn.Sequential(norm, linear)
        self.kv_cache = kv_cache
        # Declared up front so that calling forward() before store_mel_emb()
        # trips the explicit check there instead of raising AttributeError.
        self.cached_mel_emb = None

    def store_mel_emb(self, mel_emb):
        """Remember the prefix embedding that forward() prepends to the generated tokens."""
        self.cached_mel_emb = mel_emb

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        """Build the keyword arguments for one ``forward`` call during generation.

        When ``kv_cache`` is disabled the past is dropped, so every step
        re-processes the full sequence. When a past is present only the last
        token (and last token type / position id) is forwarded.
        """
        token_type_ids = kwargs.get("token_type_ids", None)  # usually None
        if not self.kv_cache:
            past_key_values = None
        # only last token for inputs_ids if past is defined in kwargs
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            # NOTE: caller-supplied position_ids are discarded on this path.
            position_ids = None
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run one generation step and return next-token logits.

        ``inputs_embeds`` and ``labels`` are not supported and must be ``None``.
        Returns a ``CausalLMOutputWithCrossAttentions`` (``loss`` is always
        ``None``), or a plain tuple starting with the logits when
        ``return_dict`` is false.
        """
        assert self.cached_mel_emb is not None, "store_mel_emb() must be called before forward()"
        assert inputs_embeds is None  # Not supported by this inference model.
        assert labels is None  # Training not supported by this inference model.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Create embedding
        mel_len = self.cached_mel_emb.shape[1]
        if input_ids.shape[1] != 1:
            # Full-sequence step: ids up to mel_len are placeholders for the
            # cached prefix, everything after them is embedded normally.
            text_inputs = input_ids[:, mel_len:]
            text_emb = self.embeddings(text_inputs)
            text_emb = text_emb + self.text_pos_embedding(text_emb)
            if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
                # Expand the cached prefix to match the generation batch size.
                mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0] // self.cached_mel_emb.shape[0], 0)
            else:  # this outcome only occurs once per loop in most cases
                mel_emb = self.cached_mel_emb
            emb = torch.cat([mel_emb, text_emb], dim=1)
        else:
            # Single-token step: the token's position is recovered from the
            # attention mask length minus the cached prefix length.
            emb = self.embeddings(input_ids)
            emb = emb + self.text_pos_embedding.get_fixed_embedding(
                attention_mask.shape[1] - mel_len, attention_mask.device
            )

        transformer_outputs = self.transformer(
            inputs_embeds=emb,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            return (lm_logits,) + transformer_outputs[1:]

        return CausalLMOutputWithCrossAttentions(
            loss=None,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(past, beam_idx):
        """
        This function is used to re-order the :obj:`past_key_values` cache if
        :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past
        )
class MelEncoder(nn.Module):
    """Convolutional encoder that turns a MEL spectrogram into a sequence of embeddings.

    Two stride-2 convolutions shorten the time axis by ``self.reduction`` (4)
    while the channel count grows from ``channels // 4`` to ``channels``.
    """

    def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
        super().__init__()
        self.channels = channels
        quarter, half = channels // 4, channels // 2

        def res_stack(width):
            # A run of residual blocks operating at a fixed channel width.
            return nn.Sequential(*[ResBlock(width) for _ in range(resblocks_per_reduction)])

        # Stages are created in their execution order so layer indices (and
        # therefore checkpoint keys) stay the same.
        stages = []
        stages.append(nn.Conv1d(mel_channels, quarter, kernel_size=3, padding=1))
        stages.append(res_stack(quarter))
        stages.append(nn.Conv1d(quarter, half, kernel_size=3, stride=2, padding=1))
        stages.append(nn.GroupNorm(channels // 16, half))
        stages.append(nn.ReLU())
        stages.append(res_stack(half))
        stages.append(nn.Conv1d(half, channels, kernel_size=3, stride=2, padding=1))
        stages.append(nn.GroupNorm(channels // 8, channels))
        stages.append(nn.ReLU())
        stages.append(res_stack(channels))
        self.encoder = nn.Sequential(*stages)
        self.reduction = 4

    def forward(self, x):
        """Encode ``x`` stage by stage and swap the last two axes of the result."""
        encoded = x
        for stage in self.encoder:
            encoded = stage(encoded)
        return encoded.permute(0, 2, 1)
): + """ + Args: + layers: Number of layers in transformer stack. + model_dim: Operating dimensions of the transformer + heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64 + max_text_tokens: Maximum number of text tokens that will be encountered by model. + max_mel_tokens: Maximum number of MEL tokens that will be encountered by model. + max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s). + mel_length_compression: The factor between and . Used to compute MEL code padding given wav input length. + number_text_tokens: + start_text_token: + stop_text_token: + number_mel_codes: + start_mel_token: + stop_mel_token: + train_solo_embeddings: + use_mel_codes_as_input: + checkpointing: + """ + super().__init__() + + self.number_text_tokens = number_text_tokens + self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token + self.stop_text_token = 0 + self.number_mel_codes = number_mel_codes + self.start_mel_token = start_mel_token + self.stop_mel_token = stop_mel_token + self.layers = layers + self.heads = heads + self.max_mel_tokens = max_mel_tokens + self.max_text_tokens = max_text_tokens + self.model_dim = model_dim + self.max_conditioning_inputs = max_conditioning_inputs + self.mel_length_compression = mel_length_compression + self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) + self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim) + if use_mel_codes_as_input: + self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim) + else: + self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1) + ( + self.gpt, + self.mel_pos_embedding, + self.text_pos_embedding, + self.mel_layer_pos_embedding, + self.text_layer_pos_embedding, + ) = build_hf_gpt_transformer( + layers, + model_dim, + heads, + 
self.max_mel_tokens + 2 + self.max_conditioning_inputs, + self.max_text_tokens + 2, + checkpointing, + ) + if train_solo_embeddings: + self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True) + self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True) + else: + self.mel_solo_embedding = 0 + self.text_solo_embedding = 0 + + self.final_norm = nn.LayerNorm(model_dim) + self.text_head = nn.Linear(model_dim, self.number_text_tokens * types + 1) + self.mel_head = nn.Linear(model_dim, self.number_mel_codes) + + # Initialize the embeddings per the GPT-2 scheme + embeddings = [self.text_embedding] + if use_mel_codes_as_input: + embeddings.append(self.mel_embedding) + for module in embeddings: + module.weight.data.normal_(mean=0.0, std=0.02) + + def post_init_gpt2_config(self, kv_cache=True): + seq_length = self.max_mel_tokens + self.max_text_tokens + 2 + gpt_config = GPT2Config( + vocab_size=self.max_mel_tokens, + n_positions=seq_length, + n_ctx=seq_length, + n_embd=self.model_dim, + n_layer=self.layers, + n_head=self.heads, + gradient_checkpointing=False, + use_cache=True, + ) + self.inference_model = GPT2InferenceModel( + gpt_config, + self.gpt, + self.mel_pos_embedding, + self.mel_embedding, + self.final_norm, + self.mel_head, + kv_cache=kv_cache, + ) + # self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head) + self.gpt.wte = self.mel_embedding + # self.inference_model.save_pretrained("") + + def build_aligned_inputs_and_targets(self, input, start_token, stop_token): + inp = F.pad(input, (1, 0), value=start_token) + tar = F.pad(input, (0, 1), value=stop_token) + return inp, tar + + def set_mel_padding(self, mel_input_tokens, wav_lengths): + """ + Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in + that audio clip, reformats the tokens with 
STOP_MEL_TOKEN in place of the zero padding. This is required + preformatting to create a working TTS model. + """ + # Set padding areas within MEL (currently it is coded with the MEL code for ). + mel_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode="trunc") + for b in range(len(mel_lengths)): + actual_end = ( + mel_lengths[b] + 1 + ) # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token. + if actual_end < mel_input_tokens.shape[-1]: + mel_input_tokens[b, actual_end:] = self.stop_mel_token + return mel_input_tokens + + def get_logits( + self, + speech_conditioning_inputs, + first_inputs, + first_head, + second_inputs=None, + second_head=None, + get_attns=False, + return_latent=False, + ): + if second_inputs is not None: + emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1) + else: + emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1) + + gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns) + if get_attns: + return gpt_out.attentions + + enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input + enc = self.final_norm(enc) + + if return_latent: + return ( + enc[ + :, + speech_conditioning_inputs.shape[1] : speech_conditioning_inputs.shape[1] + first_inputs.shape[1], + ], + enc[:, -second_inputs.shape[1] :], + ) + + first_logits = enc[:, : first_inputs.shape[1]] + first_logits = first_head(first_logits) + first_logits = first_logits.permute(0, 2, 1) + if second_inputs is not None: + second_logits = enc[:, -second_inputs.shape[1] :] + second_logits = second_head(second_logits) + second_logits = second_logits.permute(0, 2, 1) + return first_logits, second_logits + else: + return first_logits + + def get_conditioning(self, speech_conditioning_input): + speech_conditioning_input = ( + speech_conditioning_input.unsqueeze(1) + if 
len(speech_conditioning_input.shape) == 3 + else speech_conditioning_input + ) + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds = torch.stack(conds, dim=1) + conds = conds.mean(dim=1) + return conds + + def forward( + self, + speech_conditioning_latent, + text_inputs, + text_lengths, + mel_codes, + wav_lengths, + types=None, + text_first=True, + raw_mels=None, + return_attentions=False, + return_latent=False, + clip_inputs=True, + ): + """ + Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode + (actuated by `text_first`). + + speech_conditioning_input: MEL float tensor, (b,1024) + text_inputs: long tensor, (b,t) + text_lengths: long tensor, (b,) + mel_inputs: long tensor, (b,m) + wav_lengths: long tensor, (b,) + raw_mels: MEL float tensor (b,80,s) + + If return_attentions is specified, only logits are returned. + If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned. + If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality. + """ + # Types are expressed by expanding the text embedding space. + if types is not None: + text_inputs = text_inputs * (1 + types).unsqueeze(-1) + + if clip_inputs: + # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by + # chopping the inputs by the maximum actual length. 
+ max_text_len = text_lengths.max() + text_inputs = text_inputs[:, :max_text_len] + max_mel_len = wav_lengths.max() // self.mel_length_compression + mel_codes = mel_codes[:, :max_mel_len] + if raw_mels is not None: + raw_mels = raw_mels[:, :, : max_mel_len * 4] + mel_codes = self.set_mel_padding(mel_codes, wav_lengths) + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token) + + conds = speech_conditioning_latent.unsqueeze(1) + text_inputs, text_targets = self.build_aligned_inputs_and_targets( + text_inputs, self.start_text_token, self.stop_text_token + ) + text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + mel_codes, mel_targets = self.build_aligned_inputs_and_targets( + mel_codes, self.start_mel_token, self.stop_mel_token + ) + if raw_mels is not None: + mel_inp = F.pad(raw_mels, (0, 8)) + else: + mel_inp = mel_codes + mel_emb = self.mel_embedding(mel_inp) + mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + + if text_first: + text_logits, mel_logits = self.get_logits( + conds, + text_emb, + self.text_head, + mel_emb, + self.mel_head, + get_attns=return_attentions, + return_latent=return_latent, + ) + if return_latent: + return mel_logits[ + :, :-2 + ] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. + else: + mel_logits, text_logits = self.get_logits( + conds, + mel_emb, + self.mel_head, + text_emb, + self.text_head, + get_attns=return_attentions, + return_latent=return_latent, + ) + if return_latent: + return text_logits[ + :, :-2 + ] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. 
+ + if return_attentions: + return mel_logits + loss_text = F.cross_entropy(text_logits, text_targets.long()) + loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) + return loss_text.mean(), loss_mel.mean(), mel_logits + + def inference_speech( + self, + speech_conditioning_latent, + text_inputs, + input_tokens=None, + num_return_sequences=1, + max_generate_length=None, + typical_sampling=False, + typical_mass=0.9, + **hf_generate_kwargs, + ): + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + text_inputs, text_targets = self.build_aligned_inputs_and_targets( + text_inputs, self.start_text_token, self.stop_text_token + ) + text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + + conds = speech_conditioning_latent.unsqueeze(1) + emb = torch.cat([conds, text_emb], dim=1) + self.inference_model.store_mel_emb(emb) + + fake_inputs = torch.full( + ( + emb.shape[0], + conds.shape[1] + emb.shape[1], + ), + fill_value=1, + dtype=torch.long, + device=text_inputs.device, + ) + fake_inputs[:, -1] = self.start_mel_token + trunc_index = fake_inputs.shape[1] + if input_tokens is None: + inputs = fake_inputs + else: + assert ( + num_return_sequences % input_tokens.shape[0] == 0 + ), "The number of return sequences must be divisible by the number of input sequences" + fake_inputs = fake_inputs.repeat(num_return_sequences, 1) + input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1) + inputs = torch.cat([fake_inputs, input_tokens], dim=1) + + logits_processor = ( + LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList() + ) # TODO disable this + max_length = ( + trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length + ) + gen = self.inference_model.generate( + inputs, + bos_token_id=self.start_mel_token, + pad_token_id=self.stop_mel_token, + eos_token_id=self.stop_mel_token, + 
class ResBlock(nn.Module):
    """Residual 1-D convolution block with optional up/down-sampling.

    Args:
        channels: Number of input channels.
        dropout: Dropout probability applied before the final convolution.
        out_channels: Number of output channels; defaults to ``channels``.
        use_conv: When the channel count changes, use a ``kernel_size`` convolution
            for the skip path instead of a 1x1 convolution.
        use_scale_shift_norm: Stored on the instance; not otherwise used by this block.
        dims: Forwarded to ``Upsample`` / ``Downsample``.
        up: Upsample both the residual and the skip path.
        down: Downsample both paths (ignored when ``up`` is set).
        kernel_size: Convolution kernel size; padding is 1 for a kernel of 3 and 2 otherwise.
        do_checkpoint: Stored on the instance; not otherwise used by this block.
    """

    def __init__(
        self,
        channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
        kernel_size=3,
        do_checkpoint=True,
    ):
        super().__init__()
        self.channels = channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        self.do_checkpoint = do_checkpoint
        padding = 1 if kernel_size == 3 else 2

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # Zero-initialised so the block starts out as (approximately) its skip path.
            zero_module(nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)),
        )

        # The skip path must map `channels` -> `out_channels`. nn.Conv1d takes
        # (in_channels, out_channels, kernel_size, ...); a leading `dims` argument
        # would shift every parameter and build a layer that cannot accept `x`.
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding)
        else:
            self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)

    def forward(self, x):
        """Return ``skip_connection(x) + out_layers(in_layers(x))``, resampling both paths if configured."""
        if self.updown:
            # Resample after norm + activation but before the input convolution.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        h = self.out_layers(h)
        return self.skip_connection(x) + h
def masked_mean(t, mask, dim=1):
    """Average ``t`` over ``dim`` counting only the positions where ``mask`` is true.

    Args:
        t: Tensor indexed as ``t[i, j, ...]`` with one trailing feature axis.
        mask: Boolean tensor over the first two axes of ``t``; false entries are
            zeroed out and excluded from the count.
        dim: Axis to reduce. Must be one of the two leading axes shared by ``t``
            and ``mask`` (0 or 1); defaults to 1.

    Returns:
        ``t`` reduced over ``dim``. Slices with no valid position yield zeros
        rather than NaN.
    """
    t = t.masked_fill(~mask[:, :, None], 0.0)
    # Guard the denominator: a fully masked slice would otherwise divide 0 by 0.
    valid_counts = mask.sum(dim=dim)[..., None].clamp(min=1)
    return t.sum(dim=dim) / valid_counts
+ + Originally from https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py + """ + + def __init__( + self, + *, + dim_text=512, + dim_speech=512, + dim_latent=512, + num_text_tokens=256, + text_enc_depth=6, + text_seq_len=120, + text_heads=8, + num_speech_tokens=8192, + speech_enc_depth=6, + speech_heads=8, + speech_seq_len=250, + text_mask_percentage=0, + voice_mask_percentage=0, + wav_token_compression=1024, + use_xformers=False, + ): + super().__init__() + self.text_emb = nn.Embedding(num_text_tokens, dim_text) + self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False) + + self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech) + self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False) + + if use_xformers: + self.text_transformer = CheckpointedXTransformerEncoder( + needs_permute=False, + exit_permute=False, + max_seq_len=-1, + attn_layers=Encoder( + dim=dim_text, + depth=text_enc_depth, + heads=text_heads, + ff_dropout=0.1, + ff_mult=2, + attn_dropout=0.1, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + ), + ) + self.speech_transformer = CheckpointedXTransformerEncoder( + needs_permute=False, + exit_permute=False, + max_seq_len=-1, + attn_layers=Encoder( + dim=dim_speech, + depth=speech_enc_depth, + heads=speech_heads, + ff_dropout=0.1, + ff_mult=2, + attn_dropout=0.1, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + ), + ) + else: + self.text_transformer = Transformer( + causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth, heads=text_heads + ) + self.speech_transformer = Transformer( + causal=False, seq_len=speech_seq_len, dim=dim_speech, depth=speech_enc_depth, heads=speech_heads + ) + + self.temperature = nn.Parameter(torch.tensor(1.0)) + self.text_mask_percentage = text_mask_percentage + self.voice_mask_percentage = voice_mask_percentage + self.wav_token_compression = wav_token_compression + self.xformers = use_xformers + if not use_xformers: + 
if __name__ == "__main__":
    # Smoke test: CLVP.forward takes (text, speech_tokens, return_loss) only,
    # so no per-sample length tensors are passed.
    clip = CLVP(text_mask_percentage=0.2, voice_mask_percentage=0.2)
    clip(
        torch.randint(0, 256, (2, 120)),
        torch.randint(0, 8192, (2, 250)),
        return_loss=True,
    )
    nonloss = clip(
        torch.randint(0, 256, (2, 120)),
        torch.randint(0, 8192, (2, 250)),
        return_loss=False,
    )
    print(nonloss.shape)
def mean_flat(tensor):
    """
    Average a tensor over every dimension except the leading (batch) one.

    :param tensor: a tensor whose first dimension indexes the batch.
    :return: the per-batch-element mean over all remaining dimensions.
    """
    non_batch_dims = list(range(1, tensor.dim()))
    return tensor.mean(dim=non_batch_dims)
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize a continuous alpha_bar(t) curve into a beta schedule.

    alpha_bar(t) is the cumulative product of (1 - beta) as a function of
    t in [0, 1]. Each beta is one minus the ratio of alpha_bar at the end of
    an interval to alpha_bar at its start.

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a callable mapping t in [0, 1] to the cumulative product
                      of (1 - beta) up to that part of the diffusion process.
    :param max_beta: upper bound applied to every beta; values below 1 avoid
                     singularities.
    :return: a 1-D numpy array with num_diffusion_timesteps entries.
    """
    steps = num_diffusion_timesteps
    clipped = [
        min(1 - alpha_bar((idx + 1) / steps) / alpha_bar(idx / steps), max_beta)
        for idx in range(steps)
    ]
    return np.array(clipped)
class LossType(enum.Enum):
    """
    The training objective used by the diffusion process.
    """

    MSE = "mse"  # use raw MSE loss (and KL when learning variances)
    RESCALED_MSE = "rescaled_mse"  # use raw MSE loss (with RESCALED_KL when learning variances)
    KL = "kl"  # use the variational lower-bound
    RESCALED_KL = "rescaled_kl"  # like KL, but rescale to estimate the full VLB

    def is_vb(self):
        """Return True when this member is one of the two KL-based objectives."""
        return self in (LossType.KL, LossType.RESCALED_KL)
+ betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + # log calculation clipped because the posterior variance is 0 at the + # beginning of the diffusion chain. + self.posterior_log_variance_clipped = np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:])) + self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
def q_sample(self, x_start, t, noise=None):
    """
    Draw a noised sample from q(x_t | x_0).

    :param x_start: the initial (noise-free) data batch.
    :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
    :param noise: optional Gaussian noise of x_start's shape; drawn with
                  th.randn_like when omitted.
    :return: A noisy version of x_start.
    """
    if noise is None:
        noise = th.randn_like(x_start)
    assert noise.shape == x_start.shape
    # NOTE(review): _extract_into_tensor is defined outside this view; it
    # presumably picks the coefficient for each timestep in `t` and broadcasts
    # it to x_start.shape -- confirm against its definition.
    signal_scale = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
    noise_scale = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return signal_scale * x_start + noise_scale * noise
+ + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, self._scale_timesteps(t), **model_kwargs) + if self.conditioning_free: + model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs) + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + if self.conditioning_free: + model_output_no_conditioning, _ = th.split(model_output_no_conditioning, C, dim=1) + if self.model_var_type == ModelVarType.LEARNED: + model_log_variance = model_var_values + model_variance = th.exp(model_log_variance) + else: + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. 
+ frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + if self.conditioning_free: + if self.ramp_conditioning_free: + assert t.shape[0] == 1 # This should only be used in inference. + cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t)[0].item() / self.num_timesteps) + else: + cfk = self.conditioning_free_k + model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + pred_xstart = process_xstart(self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)) + model_mean = model_output + elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + else: + raise NotImplementedError(self.model_mean_type) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + 
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    """
    Compute the mean for the previous step, given a function cond_fn that
    computes the gradient of a conditional log probability with respect to
    x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
    condition on y.

    This uses the conditioning strategy from Sohl-Dickstein et al. (2015).

    :param cond_fn: callable invoked as cond_fn(x, scaled_t, **model_kwargs)
                    that returns the gradient tensor.
    :param p_mean_var: dict providing at least the "mean" and "variance"
                       tensors.
    :param x: the tensor at time t.
    :param t: a tensor of timesteps; passed through _scale_timesteps before
              being handed to cond_fn.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
                         pass to cond_fn.
    :return: mean + variance * gradient, computed with the mean and the
             gradient cast to float.
    """
    # The default is None, and unpacking None with ** raises TypeError, so
    # normalise it to an empty dict first.
    if model_kwargs is None:
        model_kwargs = {}
    gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
    new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
    return new_mean
def k_diffusion_sample_loop(
    self,
    k_sampler,
    pbar,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    device=None,
    model_kwargs=None,  # {'precomputed_aligned_embeddings': precomputed_embeddings},
    progress=False,
):
    """
    Generate a sample with the multistep DPM-Solver++ (order 2) using
    classifier-free guidance.

    :param k_sampler: unused; accepted so all sample loops share a call shape.
    :param pbar: progress object; pbar.update(1) is called once per model
                 evaluation round.
    :param model: the model, called as model(x, t, **model_kwargs) and again
                  with conditioning_free=True. Its output is split in half
                  along dim 1 and only the first half (epsilon) is used.
    :param shape: the shape of the samples; only read when `noise` is None.
    :param noise: the starting noise tensor. When None it is drawn from a
                  standard normal of `shape`.
    :param clip_denoised: unused.
    :param denoised_fn: unused.
    :param cond_fn: unused.
    :param device: device for freshly drawn noise; defaults to the device of
                   the model's first parameter. Ignored when `noise` is given.
    :param model_kwargs: dict of extra keyword arguments for the model
                         (required; must be a dict).
    :param progress: unused.
    :return: the tensor returned by DPM_Solver.sample.
    """
    assert isinstance(model_kwargs, dict)
    if noise is None:
        # Draw the starting noise from `shape` instead of failing on None.
        if device is None:
            device = next(model.parameters()).device
        noise = th.randn(*shape, device=device)

    def model_split(*args, **kwargs):
        # The model emits epsilon and variance stacked along dim 1; split them.
        model_output = model(*args, **kwargs)
        model_epsilon, model_var = th.split(model_output, model_output.shape[1] // 2, dim=1)
        return model_epsilon, model_var

    noise_schedule = NoiseScheduleVP(schedule="linear", continuous_beta_0=0.1 / 4, continuous_beta_1=20.0 / 4)

    def model_fn_prewrap(x, t, *args, **kwargs):
        # NOTE(review): model_wrapper (not visible here) appears to pass a
        # batch holding two stacked copies of x and t for classifier-free
        # guidance -- confirm. Only the first half is kept, and the model is
        # run explicitly without and with conditioning.
        x, _ = x.chunk(2)
        # NOTE(review): the solver presumably supplies continuous time in
        # [0, 1]; it is rescaled by 1000 before reaching the model -- confirm.
        t, _ = (t * 1000).chunk(2)
        res = torch.cat(
            [
                model_split(x, t, conditioning_free=True, **model_kwargs)[0],
                model_split(x, t, **model_kwargs)[0],
            ]
        )
        pbar.update(1)
        return res

    model_fn = model_wrapper(
        model_fn_prewrap,
        noise_schedule,
        model_type="noise",  # "noise" or "x_start" or "v" or "score"
        model_kwargs=model_kwargs,
        guidance_type="classifier-free",
        # Placeholder tensors: model_fn_prewrap ignores its extra positional
        # arguments, so their contents are never read by this function.
        condition=th.Tensor(1),
        unconditional_condition=th.Tensor(1),
        guidance_scale=self.conditioning_free_k,
    )
    dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
    x_sample = dpm_solver.sample(
        noise,
        steps=self.num_timesteps,
        order=2,
        skip_type="time_uniform",
        method="multistep",
    )
    return x_sample
+ - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. 
def p_sample_loop_progressive(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
):
    """
    Run the reverse diffusion chain and yield the result of every step.

    Timesteps are visited from num_timesteps - 1 down to 0. Arguments are the
    same as p_sample_loop().

    :return: a generator over dicts, where each dict is the return value of
             p_sample() for one timestep.
    """
    if device is None:
        device = next(model.parameters()).device
    assert isinstance(shape, (tuple, list))
    # Start from the caller's noise when given, otherwise from fresh noise.
    img = noise if noise is not None else th.randn(*shape, device=device)

    descending_steps = range(self.num_timesteps - 1, -1, -1)
    for step in tqdm(descending_steps, disable=not progress):
        # One timestep index per batch element.
        timesteps = th.tensor([step] * shape[0], device=device)
        with th.no_grad():
            result = self.p_sample(
                model,
                img,
                timesteps,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                cond_fn=cond_fn,
                model_kwargs=model_kwargs,
            )
            yield result
            img = result["sample"]
def ddim_reverse_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    model_kwargs=None,
    eta=0.0,
):
    """
    Sample x_{t+1} from the model using DDIM reverse ODE.

    :param model: the model to sample from.
    :param x: the current tensor at time t.
    :param t: a tensor of timesteps.
    :param clip_denoised: forwarded to p_mean_variance().
    :param denoised_fn: forwarded to p_mean_variance().
    :param model_kwargs: forwarded to p_mean_variance().
    :param eta: must be 0.0; only the deterministic path is supported.
    :return: a dict with 'sample' (the x_{t+1} prediction) and 'pred_xstart'.
    """
    assert eta == 0.0, "Reverse ODE only for deterministic path"
    prediction = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    pred_xstart = prediction["pred_xstart"]
    # Usually our model outputs epsilon, but we re-derive it
    # in case we used x_start or x_prev prediction.
    eps = self._predict_eps_from_xstart(x, t, pred_xstart)
    alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

    # Equation 12. reversed
    mean_pred = pred_xstart * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps

    return {"sample": mean_pred, "pred_xstart": pred_xstart}
def ddim_sample_loop_progressive(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
    eta=0.0,
):
    """
    Run the DDIM reverse chain and yield the result of every step.

    Timesteps are visited from num_timesteps - 1 down to 0. Same usage as
    p_sample_loop_progressive(), plus `eta`, which is forwarded to
    ddim_sample().

    :return: a generator over dicts, each the return value of ddim_sample().
    """
    if device is None:
        device = next(model.parameters()).device
    assert isinstance(shape, (tuple, list))
    img = noise if noise is not None else th.randn(*shape, device=device)

    steps = list(reversed(range(self.num_timesteps)))
    if progress:
        # The progress bar is only imported when one was requested.
        from tqdm.auto import tqdm as auto_tqdm

        steps = auto_tqdm(steps, disable=not progress)

    for step in steps:
        # One timestep index per batch element.
        timesteps = th.tensor([step] * shape[0], device=device)
        with th.no_grad():
            result = self.ddim_sample(
                model,
                img,
                timesteps,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                cond_fn=cond_fn,
                model_kwargs=model_kwargs,
                eta=eta,
            )
            yield result
            img = result["sample"]
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
    """
    Compute training losses for a single timestep.

    :param model: the model to evaluate loss on.
    :param x_start: the [N x C x ...] tensor of inputs.
    :param t: a batch of timestep indices.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :param noise: if specified, the specific Gaussian noise to try to remove.
    :return: a dict with the key "loss" containing a tensor of shape [N].
             Some mean or variance settings may also have other keys
             ("mse", "vb", "x_start_predicted", "extra_outputs").
    :raises NotImplementedError: for an unknown loss type or model mean type.
    """
    if model_kwargs is None:
        model_kwargs = {}
    if noise is None:
        noise = th.randn_like(x_start)
    x_t = self.q_sample(x_start, t, noise=noise)

    terms = {}

    if self.loss_type in (LossType.KL, LossType.RESCALED_KL):
        # TODO: support multiple model outputs for this mode.
        terms["loss"] = self._vb_terms_bpd(
            model=model,
            x_start=x_start,
            x_t=x_t,
            t=t,
            clip_denoised=False,
            model_kwargs=model_kwargs,
        )["output"]
        if self.loss_type == LossType.RESCALED_KL:
            terms["loss"] *= self.num_timesteps
    elif self.loss_type in (LossType.MSE, LossType.RESCALED_MSE):
        model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs)
        # A tuple-returning model puts the diffusion output first; the rest
        # is passed back to the caller untouched.
        if isinstance(model_outputs, tuple):
            model_output = model_outputs[0]
            terms["extra_outputs"] = model_outputs[1:]
        else:
            model_output = model_outputs

        if self.model_var_type in [
            ModelVarType.LEARNED,
            ModelVarType.LEARNED_RANGE,
        ]:
            B, C = x_t.shape[:2]
            assert model_output.shape == (B, C * 2, *x_t.shape[2:])
            model_output, model_var_values = th.split(model_output, C, dim=1)
            # Learn the variance using the variational bound, but don't let
            # it affect our mean prediction.
            frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
            terms["vb"] = self._vb_terms_bpd(
                model=lambda *args, r=frozen_out: r,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
            )["output"]
            if self.loss_type == LossType.RESCALED_MSE:
                # Divide by 1000 for equivalence with initial implementation.
                # Without a factor of 1/1000, the VB term hurts the MSE term.
                terms["vb"] *= self.num_timesteps / 1000.0

        if self.model_mean_type == ModelMeanType.PREVIOUS_X:
            target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0]
            # x_0 prediction is not supported for this mean type; report a
            # zero placeholder with x_start's shape, dtype and device.
            # zeros_like is required because torch.zeros() expects a size,
            # not a tensor.
            x_start_pred = torch.zeros_like(x_start)
        elif self.model_mean_type == ModelMeanType.START_X:
            target = x_start
            x_start_pred = model_output
        elif self.model_mean_type == ModelMeanType.EPSILON:
            target = noise
            x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
        else:
            raise NotImplementedError(self.model_mean_type)
        assert model_output.shape == target.shape == x_start.shape
        terms["mse"] = mean_flat((target - model_output) ** 2)
        terms["x_start_predicted"] = x_start_pred
        if "vb" in terms:
            terms["loss"] = terms["mse"] + terms["vb"]
        else:
            terms["loss"] = terms["mse"]
    else:
        raise NotImplementedError(self.loss_type)

    return terms
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs) + terms.update({k: o for k, o in zip(model_output_keys, model_outputs)}) + model_output = terms[gd_out_key] + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C, 2, *x_t.shape[2:]) + model_output, model_var_values = model_output[:, :, 0], model_output[:, :, 1] + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0] + x_start_pred = torch.zeros(x_start) # Not supported. 
def _prior_bpd(self, x_start):
    """
    Get the prior KL term for the variational lower-bound, measured in
    bits-per-dim.

    The term is the KL between q(x_T | x_0), evaluated at the last timestep
    index, and a normal with zero mean and zero log-variance. It can't be
    optimized, as it only depends on the encoder.

    :param x_start: the [N x C x ...] tensor of inputs.
    :return: a batch of [N] KL values (in bits), one per batch element.
    """
    n_items = x_start.shape[0]
    last_step = th.tensor([self.num_timesteps - 1] * n_items, device=x_start.device)
    qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, last_step)
    kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
    # Dividing by ln(2) converts nats to bits.
    bits = mean_flat(kl_prior) / np.log(2.0)
    return bits
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Return the pre-defined beta schedule registered under `schedule_name`.

    The schedules in this library are meant to stay similar in the limit of
    `num_diffusion_timesteps`. New schedules may be added, but existing ones
    should not be removed or changed once committed, to maintain backwards
    compatibility.

    :param schedule_name: "linear" or "cosine".
    :param num_diffusion_timesteps: the number of betas to produce.
    :return: the beta values; the linear schedule is a float64 numpy array,
             the cosine schedule is whatever `betas_for_alpha_bar` returns.
    :raises NotImplementedError: if `schedule_name` is not a known schedule.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, rescaled so that it works for any
        # number of diffusion steps (the reference end points are for 1000).
        rescale = 1000 / num_diffusion_timesteps
        first_beta = rescale * 0.0001
        last_beta = rescale * 0.02
        return np.linspace(first_beta, last_beta, num_diffusion_timesteps, dtype=np.float64)

    if schedule_name == "cosine":

        def cosine_alpha_bar(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        return betas_for_alpha_bar(num_diffusion_timesteps, cosine_alpha_bar)

    raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.

    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    :raises ValueError: if no integer stride yields exactly N steps for
                        "ddimN", or if a section has fewer steps than the
                        count requested for it.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim") :])
            # Search for an integer stride that produces exactly the requested count.
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            # Report the count that was asked for, not the size of the original process.
            raise ValueError(f"cannot create exactly {desired_count} steps with an integer stride")
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    # The first `extra` sections each absorb one leftover step.
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(f"cannot divide section of {size} steps into {section_count}")
        if section_count <= 1:
            frac_stride = 1
        else:
            # Fractional stride so the first and last step of the section are both taken.
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Build sinusoidal embeddings for a batch of timesteps.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings; the cosine half
             comes first, then the sine half, then one zero column when
             `dim` is odd.
    """
    n_freqs = dim // 2
    exponents = torch.arange(start=0, end=n_freqs, dtype=torch.float32)
    # Geometric progression of frequencies from 1 down towards 1 / max_period.
    freqs = torch.exp(-math.log(max_period) * exponents / n_freqs)
    freqs = freqs.to(device=timesteps.device)
    phases = timesteps.unsqueeze(1).float() * freqs.unsqueeze(0)
    columns = [torch.cos(phases), torch.sin(phases)]
    if dim % 2:
        # Pad with a zero column so the output width matches an odd `dim`.
        columns.append(torch.zeros_like(phases[:, :1]))
    return torch.cat(columns, dim=-1)
class DiffusionLayer(TimestepBlock):
    """
    A timestep-conditioned residual block followed by an attention block.

    :param model_channels: channel count used for the input, the timestep
                           embedding and the output of the residual block.
    :param dropout: dropout value forwarded to the residual block.
    :param num_heads: number of heads for the attention block.
    """

    def __init__(self, model_channels, dropout, num_heads):
        super().__init__()
        # Sub-modules are registered in the same order as before (resblk, then attn).
        self.resblk = ResBlock(
            channels=model_channels,
            emb_channels=model_channels,
            dropout=dropout,
            out_channels=model_channels,
            dims=1,
            use_scale_shift_norm=True,
        )
        self.attn = AttentionBlock(
            model_channels,
            num_heads,
            relative_pos_embeddings=True,
        )

    def forward(self, x, time_emb):
        """
        Apply the residual block to `x` with `time_emb`, then attention.
        """
        return self.attn(self.resblk(x, time_emb))
+ # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally + # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive + # transformer network. + self.code_embedding = nn.Embedding(in_tokens, model_channels) + self.code_converter = nn.Sequential( + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + ) + self.code_norm = normalization(model_channels) + self.latent_conditioner = nn.Sequential( + nn.Conv1d(in_latent_channels, model_channels, 3, padding=1), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + ) + self.contextual_embedder = nn.Sequential( + nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2), + nn.Conv1d(model_channels, model_channels * 2, 3, padding=1, stride=2), + AttentionBlock( + model_channels * 2, + num_heads, + relative_pos_embeddings=True, + do_checkpoint=False, + ), + AttentionBlock( + model_channels * 2, + num_heads, + relative_pos_embeddings=True, + do_checkpoint=False, + ), + AttentionBlock( + model_channels * 2, + num_heads, + relative_pos_embeddings=True, + do_checkpoint=False, + ), + AttentionBlock( + model_channels * 2, + num_heads, + relative_pos_embeddings=True, + do_checkpoint=False, + ), + AttentionBlock( + model_channels * 2, + num_heads, + relative_pos_embeddings=True, + do_checkpoint=False, + ), + ) + self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels, 1)) + self.conditioning_timestep_integrator = TimestepEmbedSequential( + DiffusionLayer(model_channels, dropout, 
num_heads), + DiffusionLayer(model_channels, dropout, num_heads), + DiffusionLayer(model_channels, dropout, num_heads), + ) + + self.integrating_conv = nn.Conv1d(model_channels * 2, model_channels, kernel_size=1) + self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1) + + self.layers = nn.ModuleList( + [DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] + + [ + ResBlock( + model_channels, + model_channels, + dropout, + dims=1, + use_scale_shift_norm=True, + ) + for _ in range(3) + ] + ) + + self.out = nn.Sequential( + normalization(model_channels), + nn.SiLU(), + nn.Conv1d(model_channels, out_channels, 3, padding=1), + ) + + def get_grad_norm_parameter_groups(self): + groups = { + "minicoder": list(self.contextual_embedder.parameters()), + "layers": list(self.layers.parameters()), + "code_converters": list(self.code_embedding.parameters()) + + list(self.code_converter.parameters()) + + list(self.latent_conditioner.parameters()) + + list(self.latent_conditioner.parameters()), + "timestep_integrator": list(self.conditioning_timestep_integrator.parameters()) + + list(self.integrating_conv.parameters()), + "time_embed": list(self.time_embed.parameters()), + } + return groups + + def get_conditioning(self, conditioning_input): + speech_conditioning_input = ( + conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input + ) + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.contextual_embedder(speech_conditioning_input[:, j])) + conds = torch.cat(conds, dim=-1) + conds = conds.mean(dim=-1) + return conds + + def timestep_independent( + self, + aligned_conditioning, + conditioning_latent, + expected_seq_len, + return_code_pred, + ): + # Shuffle aligned_latent to BxCxS format + if is_latent(aligned_conditioning): + aligned_conditioning = aligned_conditioning.permute(0, 2, 1) + + cond_scale, cond_shift = torch.chunk(conditioning_latent, 2, dim=1) + if 
is_latent(aligned_conditioning): + code_emb = self.latent_conditioner(aligned_conditioning) + else: + code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1) + code_emb = self.code_converter(code_emb) + code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1) + + unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device) + # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. + if self.training and self.unconditioned_percentage > 0: + unconditioned_batches = ( + torch.rand((code_emb.shape[0], 1, 1), device=code_emb.device) < self.unconditioned_percentage + ) + code_emb = torch.where( + unconditioned_batches, + self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1), + code_emb, + ) + expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode="nearest") + + if not return_code_pred: + return expanded_code_emb + else: + mel_pred = self.mel_head(expanded_code_emb) + # Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss. + mel_pred = mel_pred * unconditioned_batches.logical_not() + return expanded_code_emb, mel_pred + + def forward( + self, + x, + timesteps, + aligned_conditioning=None, + conditioning_latent=None, + precomputed_aligned_embeddings=None, + conditioning_free=False, + return_code_pred=False, + ): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced. + :param conditioning_latent: a pre-computed conditioning latent; see get_conditioning(). 
+ :param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent() + :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered. + :return: an [N x C x ...] Tensor of outputs. + """ + assert precomputed_aligned_embeddings is not None or ( + aligned_conditioning is not None and conditioning_latent is not None + ) + assert not ( + return_code_pred and precomputed_aligned_embeddings is not None + ) # These two are mutually exclusive. + + unused_params = [] + if conditioning_free: + code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) + unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) + unused_params.extend(list(self.latent_conditioner.parameters())) + else: + if precomputed_aligned_embeddings is not None: + code_emb = precomputed_aligned_embeddings + else: + code_emb, mel_pred = self.timestep_independent( + aligned_conditioning, conditioning_latent, x.shape[-1], True + ) + if is_latent(aligned_conditioning): + unused_params.extend( + list(self.code_converter.parameters()) + list(self.code_embedding.parameters()) + ) + else: + unused_params.extend(list(self.latent_conditioner.parameters())) + + unused_params.append(self.unconditioned_embedding) + + time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + code_emb = self.conditioning_timestep_integrator(code_emb, time_emb) + x = self.inp_block(x) + x = torch.cat([x, code_emb], dim=1) + x = self.integrating_conv(x) + for i, lyr in enumerate(self.layers): + # Do layer drop where applicable. Do not drop first and last layers. + if ( + self.training + and self.layer_drop > 0 + and i != 0 + and i != (len(self.layers) - 1) + and random.random() < self.layer_drop + ): + unused_params.extend(list(lyr.parameters())) + else: + # First and last blocks will have autocast disabled for improved precision. 
+ with autocast(x.device.type, enabled=self.enable_fp16 and i != 0): + x = lyr(x, time_emb) + + x = x.float() + out = self.out(x) + + # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. + extraneous_addition = 0 + for p in unused_params: + extraneous_addition = extraneous_addition + p.mean() + out = out + extraneous_addition * 0 + + if return_code_pred: + return out, mel_pred + return out + + +if __name__ == "__main__": + clip = torch.randn(2, 100, 400) + aligned_latent = torch.randn(2, 388, 512) + aligned_sequence = torch.randint(0, 8192, (2, 100)) + cond = torch.randn(2, 100, 400) + ts = torch.LongTensor([600, 600]) + model = DiffusionTts(512, layer_drop=0.3, unconditioned_percentage=0.5) + # Test with latent aligned conditioning + # o = model(clip, ts, aligned_latent, cond) + # Test with sequence aligned conditioning + o = model(clip, ts, aligned_sequence, cond) diff --git a/TTS/tts/layers/tortoise/dpm_solver.py b/TTS/tts/layers/tortoise/dpm_solver.py new file mode 100644 index 0000000000..c70888df42 --- /dev/null +++ b/TTS/tts/layers/tortoise/dpm_solver.py @@ -0,0 +1,1562 @@ +import math + +import torch + + +class NoiseScheduleVP: + def __init__( + self, + schedule="discrete", + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20.0, + dtype=torch.float32, + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. 
For t in [0, T], we have: + + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. For continuous-time DPMs: + + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. 
The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ["discrete", "linear", "cosine"]: + raise ValueError( + "Unsupported noise schedule {}. 
The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( + schedule + ) + ) + + self.schedule = schedule + if schedule == "discrete": + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1.0 + self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) + self.log_alpha_array = log_alphas.reshape( + ( + 1, + -1, + ) + ).to(dtype=dtype) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999.0 + self.cosine_t_max = ( + math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi) + * 2.0 + * (1.0 + self.cosine_s) + / math.pi + - self.cosine_s + ) + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)) + self.schedule = schedule + if schedule == "cosine": + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1.0 + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == "discrete": + return interpolate_fn( + t.reshape((-1, 1)), + self.t_array.to(t.device), + self.log_alpha_array.to(t.device), + ).reshape((-1)) + elif self.schedule == "linear": + return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == "cosine": + + def log_alpha_fn(s): + return torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)) + + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. 
def model_wrapper(
    model,
    noise_schedule,
    model_type="noise",
    model_kwargs=None,
    guidance_type="uncond",
    condition=None,
    unconditional_condition=None,
    guidance_scale=1.0,
    classifier_fn=None,
    classifier_kwargs=None,
):
    """Create a wrapper function for the noise prediction model.

    DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
    firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.

    We support four types of the diffusion model by setting `model_type`:

        1. "noise": noise prediction model. (Trained by predicting noise).

        2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).

        3. "v": velocity prediction model. (Trained by predicting the velocity).
            The derivation of the "v" prediction is detailed in Appendix D of [1], and it is used in Imagen-Video [2].

            [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
                arXiv preprint arXiv:2202.00512 (2022).
            [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
                arXiv preprint arXiv:2210.02303 (2022).

        4. "score": marginal score function. (Trained by denoising score matching).
            Note that the score function and the noise prediction model follow a simple relationship:
            ```
                noise(x_t, t) = -sigma_t * score(x_t, t)
            ```

    We support three types of guided sampling by DPMs by setting `guidance_type`:
        1. "uncond": unconditional sampling by DPMs.
            The input `model` has the following format:
            ``
                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
            ``

        2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
            The input `model` has the following format:
            ``
                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
            ``

            The input `classifier_fn` has the following format:
            ``
                classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
            ``

            [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
                in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.

        3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
            The input `model` has the following format:
            ``
                model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
            ``
            And if cond == `unconditional_condition`, the model output is the unconditional DPM output.

            [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
                arXiv preprint arXiv:2207.12598 (2022).


    The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
    or continuous-time labels (i.e. epsilon to T).

    We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
    ``
        def model_fn(x, t_continuous) -> noise:
            t_input = get_model_input_time(t_continuous)
            return noise_pred(model, x, t_input, **model_kwargs)
    ``
    where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.

    ===============================================================

    Args:
        model: A diffusion model with the corresponding format described above.
        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
        model_type: A `str`. The parameterization type of the diffusion model.
                    "noise" or "x_start" or "v" or "score".
        model_kwargs: A `dict` or `None`. The other inputs of the model function; `None` means no extra inputs.
        guidance_type: A `str`. The type of the guidance for sampling.
                    "uncond" or "classifier" or "classifier-free".
        condition: A pytorch tensor. The condition for the guided sampling.
                    Only used for "classifier" or "classifier-free" guidance type.
        unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
                    Only used for "classifier-free" guidance type.
        guidance_scale: A `float`. The scale for the guided sampling.
        classifier_fn: A classifier function. Only used for the classifier guidance.
        classifier_kwargs: A `dict` or `None`. The other inputs of the classifier function; `None` means no extra inputs.
    Returns:
        A noise prediction model that accepts the noised data and the continuous time as the inputs.
    """
    assert model_type in ["noise", "x_start", "v", "score"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]

    # Fresh dicts per call: a `{}` default would be one object shared by every wrapper.
    model_kwargs = {} if model_kwargs is None else model_kwargs
    classifier_kwargs = {} if classifier_kwargs is None else classifier_kwargs

    def per_sample(coeff, x):
        """
        Reshape a per-sample coefficient so that it broadcasts along the batch (first) axis of `x`.

        The schedule returns one value per batch element; without this reshape a `(B,)` coefficient
        would be aligned with the last axis of a `(B, C, ...)` tensor.
        """
        return coeff.reshape(-1, *([1] * (x.dim() - 1)))

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == "discrete":
            return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        """
        Call the wrapped model and convert its output to a noise prediction according to `model_type`.
        """
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, **model_kwargs)
        else:
            output = model(x, t_input, cond, **model_kwargs)
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t = per_sample(noise_schedule.marginal_alpha(t_continuous), x)
            sigma_t = per_sample(noise_schedule.marginal_std(t_continuous), x)
            return (x - alpha_t * output) / sigma_t
        elif model_type == "v":
            alpha_t = per_sample(noise_schedule.marginal_alpha(t_continuous), x)
            sigma_t = per_sample(noise_schedule.marginal_std(t_continuous), x)
            return alpha_t * output + sigma_t * x
        elif model_type == "score":
            sigma_t = per_sample(noise_schedule.marginal_std(t_continuous), x)
            return -sigma_t * output

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = per_sample(noise_schedule.marginal_std(t_continuous), x)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * sigma_t * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1.0 or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                # Run the unconditional and conditional passes as one doubled batch.
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    return model_fn
+ + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++". + correcting_x0_fn: A `str` or a function with the following format: + ``` + def correcting_x0_fn(x0, t): + x0_new = ... + return x0_new + ``` + This function is to correct the outputs of the data prediction model at each sampling step. e.g., + ``` + x0_pred = data_pred_model(xt, t) + if correcting_x0_fn is not None: + x0_pred = correcting_x0_fn(x0_pred, t) + xt_1 = update(x0_pred, xt, t) + ``` + If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1]. + correcting_xt_fn: A function with the following format: + ``` + def correcting_xt_fn(xt, t, step): + x_new = ... + return x_new + ``` + This function is to correct the intermediate samples xt at each sampling step. e.g., + ``` + xt = ... + xt = correcting_xt_fn(xt, t, step) + ``` + thresholding_max_val: A `float`. The max value for thresholding. + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details). + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, + Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models + with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. 
+ """ + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) + self.noise_schedule = noise_schedule + assert algorithm_type in ["dpmsolver", "dpmsolver++"] + self.algorithm_type = algorithm_type + if correcting_x0_fn == "dynamic_thresholding": + self.correcting_x0_fn = self.dynamic_thresholding_fn + else: + self.correcting_x0_fn = correcting_x0_fn + self.correcting_xt_fn = correcting_xt_fn + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.thresholding_max_val = thresholding_max_val + + def dynamic_thresholding_fn(self, x0, t): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims( + torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), + dims, + ) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with corrector). + """ + noise = self.noise_prediction_fn(x, t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - sigma_t * noise) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0, t) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.algorithm_type == "dpmsolver++": + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) 
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == "logSNR": + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == "time_uniform": + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == "time_quadratic": + t_order = 2 + t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError( + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type) + ) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. 
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + + ============================================ + Args: + order: An `int`. The max order for the solver (1, 2 or 3). + steps: An `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + A tuple `(timesteps_outer, orders)`: the outer time steps for sampling, and a list of the solver order of each step.
+ """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [ + 3, + ] * ( + K - 2 + ) + [2, 1] + elif steps % 3 == 1: + orders = [ + 3, + ] * ( + K - 1 + ) + [1] + else: + orders = [ + 3, + ] * ( + K - 1 + ) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [ + 2, + ] * K + else: + K = steps // 2 + 1 + orders = [ + 2, + ] * ( + K - 1 + ) + [1] + elif order == 1: + K = 1 + orders = [ + 1, + ] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == "logSNR": + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ + torch.cumsum( + torch.tensor( + [ + 0, + ] + + orders + ), + 0, + ).to(device) + ] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
+ """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = sigma_t / sigma_s * x - alpha_t * phi_1 * model_s + if return_intermediate: + return x_t, {"model_s": model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s + if return_intermediate: + return x_t, {"model_s": model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update( + self, + x, + s, + t, + r1=0.5, + model_s=None, + return_intermediate=False, + solver_type="dpmsolver", + ): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
+ """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(t), + ) + sigma_s, sigma_s1, sigma_t = ( + ns.marginal_std(s), + ns.marginal_std(s1), + ns.marginal_std(t), + ) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + if solver_type == "dpmsolver": + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == "taylor": + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1.0 / r1) * (alpha_t * (phi_1 / h + 1.0)) * (model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + if solver_type == "dpmsolver": + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == "taylor": + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (1.0 / r1) * (sigma_t * (phi_1 / h - 1.0)) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {"model_s": model_s, "model_s1": model_s1} + else: 
+ return x_t + + def singlestep_dpm_solver_third_update( + self, + x, + s, + t, + r1=1.0 / 3.0, + r2=2.0 / 3.0, + model_s=None, + model_s1=None, + return_intermediate=False, + solver_type="dpmsolver", + ): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
+ """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1.0 / 3.0 + if r2 is None: + r2 = 2.0 / 3.0 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(s2), + ns.marginal_log_mean_coeff(t), + ) + sigma_s, sigma_s1, sigma_s2, sigma_t = ( + ns.marginal_std(s), + ns.marginal_std(s1), + ns.marginal_std(s2), + ns.marginal_std(t), + ) + alpha_s1, alpha_s2, alpha_t = ( + torch.exp(log_alpha_s1), + torch.exp(log_alpha_s2), + torch.exp(log_alpha_t), + ) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0 + phi_2 = phi_1 / h + 1.0 + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (sigma_s2 / sigma_s) * x + - (alpha_s2 * phi_12) * model_s + + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == "dpmsolver": + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1.0 / r2) * (alpha_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (alpha_t * phi_2) * D1 + - (alpha_t * 
phi_3) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0 + phi_2 = phi_1 / h - 1.0 + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (torch.exp(log_alpha_s2 - log_alpha_s)) * x + - (sigma_s2 * phi_12) * model_s + - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == "dpmsolver": + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (1.0 / r2) * (sigma_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + + if return_intermediate: + return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
+ """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] + t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] + lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if solver_type == "dpmsolver": + x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0 + elif solver_type == "taylor": + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * (phi_1 / h + 1.0)) * D1_0 + ) + else: + phi_1 = torch.expm1(h) + if solver_type == "dpmsolver": + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - 0.5 * (sigma_t * phi_1) * D1_0 + ) + elif solver_type == "taylor": + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * (phi_1 / h - 1.0)) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. 
The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_2), + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1) + D1_1 = (1.0 / r1) * (model_prev_1 - model_prev_2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + phi_2 = phi_1 / h + 1.0 + phi_3 = phi_2 / h - 0.5 + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_1 = torch.expm1(h) + phi_2 = phi_1 / h - 1.0 + phi_3 = phi_2 / h - 0.5 + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + return x_t + + def singlestep_dpm_solver_update( + self, + x, + s, + t, + order, + return_intermediate=False, + solver_type="dpmsolver", + r1=None, + r2=None, + ): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. 
The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update( + x, + s, + t, + return_intermediate=return_intermediate, + solver_type=solver_type, + r1=r1, + ) + elif order == 3: + return self.singlestep_dpm_solver_third_update( + x, + s, + t, + return_intermediate=return_intermediate, + solver_type=solver_type, + r1=r1, + r2=r2, + ) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. 
+ Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive( + self, + x, + order, + t_T, + t_0, + h_init=0.05, + atol=0.0078, + rtol=0.05, + theta=0.9, + t_err=1e-5, + solver_type="dpmsolver", + ): + """ + The adaptive step size solver based on singlestep DPM-Solver. + + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. 
Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((1,)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + + def lower_update(x, s, t): + return self.dpm_solver_first_update(x, s, t, return_intermediate=True) + + def higher_update(x, s, t, **kwargs): + return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) + + elif order == 3: + r1, r2 = 1.0 / 3.0, 2.0 / 3.0 + + def lower_update(x, s, t): + return self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type + ) + + def higher_update(x, s, t, **kwargs): + return self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) + + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max( + torch.ones_like(x).to(x) * atol, + rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)), + ) + + def norm_fn(v): + return torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.0): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min( + theta * h * torch.float_power(E, -1.0 / order).float(), + lambda_0 - lambda_s, + ) + nfe += order + print("adaptive solver nfe", nfe) + return x + + def add_noise(self, x, t, noise=None): + """ + Compute the noised input xt = alpha_t * x + sigma_t * noise. 
+ + Args: + x: A `torch.Tensor` with shape `(batch_size, *shape)`. + t: A `torch.Tensor` with shape `(t_size,)`. + Returns: + xt with shape `(t_size, batch_size, *shape)`. + """ + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + if noise is None: + noise = torch.randn((t.shape[0], *x.shape), device=x.device) + x = x.reshape((-1, *x.shape)) + xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise + if t.shape[0] == 1: + return xt.squeeze(0) + else: + return xt + + def inverse( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=2, + skip_type="time_uniform", + method="multistep", + lower_order_final=True, + denoise_to_zero=False, + solver_type="dpmsolver", + atol=0.0078, + rtol=0.05, + return_intermediate=False, + ): + """ + Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver. + For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training. + """ + t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start + t_T = self.noise_schedule.T if t_end is None else t_end + assert ( + t_0 > 0 and t_T > 0 + ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + return self.sample( + x, + steps=steps, + t_start=t_0, + t_end=t_T, + order=order, + skip_type=skip_type, + method=method, + lower_order_final=lower_order_final, + denoise_to_zero=denoise_to_zero, + solver_type=solver_type, + atol=atol, + rtol=rtol, + return_intermediate=return_intermediate, + ) + + def sample( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=2, + skip_type="time_uniform", + method="multistep", + lower_order_final=True, + denoise_to_zero=False, + solver_type="dpmsolver", + atol=0.0078, + rtol=0.05, + return_intermediate=False, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. 
+ + ===================================================== + + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. 
+ - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + + ===================================================== + + Some advice for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g., DPM-Solver: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + e.g., DPM-Solver++: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`. + e.g. 
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**. + - 'time_quadratic': quadratic time for the time steps. + + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `t_start` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + + This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). This trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolution images + (such as CIFAR-10). 
However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + return_intermediate: A `bool`. Whether to save the xt at each step. + When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + + """ + t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert ( + t_0 > 0 and t_T > 0 + ), "Time range needs to be greater than 0. 
For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + if return_intermediate: + assert method in [ + "multistep", + "singlestep", + "singlestep_fixed", + ], "Cannot use adaptive solver when saving intermediate values" + if self.correcting_xt_fn is not None: + assert method in [ + "multistep", + "singlestep", + "singlestep_fixed", + ], "Cannot use adaptive solver when correcting_xt_fn is not None" + device = x.device + intermediates = [] + with torch.no_grad(): + if method == "adaptive": + x = self.dpm_solver_adaptive( + x, + order=order, + t_T=t_T, + t_0=t_0, + atol=atol, + rtol=rtol, + solver_type=solver_type, + ) + elif method == "multistep": + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + t_prev_list = [t] + model_prev_list = [self.model_fn(x, t)] + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + # Init the first `order` values by lower order multistep DPM-Solver. + for step in range(1, order): + t = timesteps[step] + x = self.multistep_dpm_solver_update( + x, + model_prev_list, + t_prev_list, + t, + step, + solver_type=solver_type, + ) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + t_prev_list.append(t) + model_prev_list.append(self.model_fn(x, t)) + # Compute the remaining values by `order`-th order multistep DPM-Solver. 
+ for step in range(order, steps + 1): + t = timesteps[step] + # We only use lower order for steps < 10 + if lower_order_final and steps < 10: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update( + x, + model_prev_list, + t_prev_list, + t, + step_order, + solver_type=solver_type, + ) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = t + # We do not need to evaluate the final model value. + if step < steps: + model_prev_list[-1] = self.model_fn(x, t) + elif method in ["singlestep", "singlestep_fixed"]: + if method == "singlestep": + ( + timesteps_outer, + orders, + ) = self.get_orders_and_timesteps_for_singlestep_solver( + steps=steps, + order=order, + skip_type=skip_type, + t_T=t_T, + t_0=t_0, + device=device, + ) + elif method == "singlestep_fixed": + K = steps // order + orders = [ + order, + ] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for step, order in enumerate(orders): + s, t = timesteps_outer[step], timesteps_outer[step + 1] + timesteps_inner = self.get_time_steps( + skip_type=skip_type, + t_T=s.item(), + t_0=t.item(), + N=order, + device=device, + ) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + else: + raise ValueError("Got wrong method {}".format(method)) + if denoise_to_zero: + t = 
def interpolate_fn(x, xp, yp):
    """
    Evaluate a piecewise linear function y = f(x) defined by the keypoints (xp, yp).

    The computation uses only differentiable tensor ops, so it works with autograd.
    Queries outside the range of `xp` are extrapolated along the line through the
    two outermost keypoints on that side.

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size and C is the number of channels.
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    batch, num_keys = x.shape[0], xp.shape[1]

    # Put each query in front of its channel's keypoints and sort them together.
    merged = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((batch, 1, 1))], dim=2)
    merged_sorted, permutation = torch.sort(merged, dim=2)

    # The query sat at index 0 before sorting, so the position of 0 in the
    # permutation is the query's rank among the keypoints.
    rank = torch.argmin(permutation, dim=2)
    left = rank - 1
    is_below = torch.eq(rank, 0)
    is_above = torch.eq(rank, num_keys)

    # Shared by the x- and y-lookups: clamp queries above every keypoint to the
    # last segment, otherwise use the neighbour just left of the query.
    interior_or_top = torch.where(is_above, torch.tensor(num_keys - 2, device=x.device), left)

    # Segment end points, indexed in the merged (query + keypoints) ordering.
    lo_idx = torch.where(is_below, torch.tensor(1, device=x.device), interior_or_top)
    hi_idx = torch.where(torch.eq(lo_idx, left), lo_idx + 2, lo_idx + 1)
    x_lo = torch.gather(merged_sorted, dim=2, index=lo_idx.unsqueeze(2)).squeeze(2)
    x_hi = torch.gather(merged_sorted, dim=2, index=hi_idx.unsqueeze(2)).squeeze(2)

    # The same segment, indexed in the keypoint-only ordering of `yp`.
    y_idx = torch.where(is_below, torch.tensor(0, device=x.device), interior_or_top)
    yp_batched = yp.unsqueeze(0).expand(batch, -1, -1)
    y_lo = torch.gather(yp_batched, dim=2, index=y_idx.unsqueeze(2)).squeeze(2)
    y_hi = torch.gather(yp_batched, dim=2, index=(y_idx + 1).unsqueeze(2)).squeeze(2)

    return y_lo + (x - x_lo) * (y_hi - y_lo) / (x_hi - x_lo)
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2**0.5):
    """Apply an optional per-channel bias, a leaky ReLU, and a constant gain.

    Args:
        input: tensor whose dimension 1 is the channel dimension.
        bias: optional 1-D tensor added along dimension 1 before the activation.
        negative_slope: slope of the leaky ReLU for negative inputs.
        scale: factor the activated tensor is multiplied by.

    Returns:
        ``leaky_relu(input [+ bias], negative_slope) * scale``.
    """
    if bias is not None:
        # Reshape the bias to (1, C, 1, ..., 1) so it broadcasts over dimension 1.
        rest_dim = [1] * (input.ndim - bias.ndim - 1)
        input = input + bias.view(1, bias.shape[0], *rest_dim)
    # Both paths honour `negative_slope`; the no-bias path used to hard-code 0.2.
    return F.leaky_relu(input, negative_slope=negative_slope) * scale


class EqualLinear(nn.Module):
    """Linear layer with run-time weight scaling followed by `fused_leaky_relu`.

    Weights are drawn from a standard normal divided by ``lr_mul`` and multiplied
    by ``lr_mul / sqrt(in_dim)`` on every forward pass; the bias, when present, is
    multiplied by ``lr_mul``.
    """

    def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None
        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        out = F.linear(input, self.weight * self.scale)
        # `bias=False` leaves self.bias as None; scaling None would raise TypeError.
        bias = None if self.bias is None else self.bias * self.lr_mul
        return fused_leaky_relu(out, bias)


class RandomLatentConverter(nn.Module):
    """Map fresh Gaussian noise to a latent vector through five `EqualLinear` layers and one `nn.Linear`."""

    def __init__(self, channels):
        super().__init__()
        self.layers = nn.Sequential(
            *[EqualLinear(channels, channels, lr_mul=0.1) for _ in range(5)], nn.Linear(channels, channels)
        )
        self.channels = channels

    def forward(self, ref):
        """Return a (ref.shape[0], channels) latent; `ref` only supplies the batch size and device."""
        noise = torch.randn(ref.shape[0], self.channels, device=ref.device)
        return self.layers(noise)
# helpers


def exists(val):
    """Return True when `val` is not None."""
    return val is not None


def default(val, d):
    """Return `val`, or `d` when `val` is None."""
    return val if exists(val) else d


def cast_tuple(val, depth=1):
    """Coerce `val` to a tuple: lists are converted, tuples pass through, anything else is repeated `depth` times."""
    if isinstance(val, list):
        val = tuple(val)
    return val if isinstance(val, tuple) else (val,) * depth


def max_neg_value(t):
    """Return the most negative finite value representable in `t`'s dtype."""
    return -torch.finfo(t.dtype).max


def stable_softmax(t, dim=-1, alpha=32**2):
    """Softmax over `dim`, computed on logits that are scaled down by `alpha`, max-shifted, then scaled back up."""
    t = t / alpha
    t = t - torch.amax(t, dim=dim, keepdim=True).detach()
    return (t * alpha).softmax(dim=dim)


def route_args(router, args, depth):
    """Split keyword arguments between the two sub-layers of every block.

    Args:
        router: dict mapping an argument name to one `(to_f, to_g)` boolean pair per block.
        args: keyword arguments passed to the forward call; names missing from `router` are dropped.
        depth: number of blocks.

    Returns:
        A list of `depth` tuples `(f_kwargs, g_kwargs)`.
    """
    routed_args = [(dict(), dict()) for _ in range(depth)]
    for key in (k for k in args if k in router):
        val = args[key]
        # `layer_idx` replaces a loop variable that used to shadow the `depth` parameter.
        for layer_idx, ((f_args, g_args), (to_f, to_g)) in enumerate(zip(routed_args, router[key])):
            new_f_args = {key: val} if to_f else {}
            new_g_args = {key: val} if to_g else {}
            routed_args[layer_idx] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
    return routed_args


# classes
class SequentialSequence(nn.Module):
    """Run a stack of `(f, g)` layer pairs, applying each with a residual connection.

    Args:
        layers: iterable of `(f, g)` module pairs.
        args_route: optional routing map consumed by `route_args`; defaults to no routing.
        layer_dropout: stored on the instance; not used by `forward`.
    """

    def __init__(self, layers, args_route=None, layer_dropout=0.0):
        super().__init__()
        # A fresh dict per instance avoids sharing one mutable default between instances.
        args_route = {} if args_route is None else args_route
        assert all(
            len(route) == len(layers) for route in args_route.values()
        ), "each argument route map must have the same depth as the number of sequential layers"
        self.layers = layers
        self.args_route = args_route
        self.layer_dropout = layer_dropout

    def forward(self, x, **kwargs):
        args = route_args(self.args_route, kwargs, len(self.layers))
        for (f, g), (f_args, g_args) in zip(self.layers, args):
            x = x + f(x, **f_args)
            x = x + g(x, **g_args)
        return x


class DivideMax(nn.Module):
    """Divide the input by its (detached) maximum along `dim`."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        maxes = x.amax(dim=self.dim, keepdim=True).detach()
        return x / maxes


# https://arxiv.org/abs/2103.17239
class LayerScale(nn.Module):
    """Multiply the output of `fn` by a learned per-channel scale.

    The scale is initialised to 0.1 for depth <= 18, 1e-5 for depth <= 24, and 1e-6 beyond that.
    """

    def __init__(self, dim, depth, fn):
        super().__init__()
        if depth <= 18:
            init_eps = 0.1
        elif depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale


# layer norm


class PreNorm(nn.Module):
    """Apply LayerNorm before `fn`, and a second LayerNorm after it when `sandwich` is True."""

    def __init__(self, dim, fn, sandwich=False):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        x = self.fn(x, **kwargs)
        return self.norm_out(x)


# feed forward


class GEGLU(nn.Module):
    """Split the last dimension in two halves and gate the first with GELU of the second."""

    def forward(self, x):
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)


class FeedForward(nn.Module):
    """Position-wise feed-forward block: Linear -> GEGLU -> Dropout -> Linear.

    Args:
        dim: input and output feature size.
        dropout: dropout probability applied after the gated activation.
        mult: hidden-size multiplier; the hidden width is `int(dim * mult)`.
    """

    def __init__(self, dim, dropout=0.0, mult=4.0):
        super().__init__()
        # nn.Linear needs integer sizes; the float default `mult=4.0` used to yield float sizes.
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),
        )

    def forward(self, x):
        return self.net(x)


# Attention


class Attention(nn.Module):
    """Multi-head self-attention with an optional key mask and optional causal masking.

    Args:
        dim: input and output feature size.
        seq_len: stored on the instance; not used by `forward`.
        causal: when True, position i cannot attend to positions after it.
        heads: number of attention heads.
        dim_head: feature size of each head.
        dropout: dropout probability applied to the output projection.
    """

    def __init__(self, dim, seq_len, causal=True, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.seq_len = seq_len
        self.scale = dim_head**-0.5

        self.causal = causal

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x, mask=None):
        """Attend over `x` (batch, length, dim); `mask` is a boolean (batch, length) tensor, True for keys to keep."""
        h, device = self.heads, x.device

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)

        q = q * self.scale

        dots = torch.einsum("b h i d, b h j d -> b h i j", q, k)
        mask_value = max_neg_value(dots)

        if exists(mask):
            mask = rearrange(mask, "b j -> b () () j")
            dots.masked_fill_(~mask, mask_value)
            del mask

        if self.causal:
            i, j = dots.shape[-2:]
            # Strictly-upper triangle (offset for i != j) marks the future positions to hide.
            mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool()
            dots.masked_fill_(mask, mask_value)

        attn = torch.softmax(dots, dim=-1)

        out = torch.einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.to_out(out)
        return out
DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models")
# Honour TORTOISE_MODELS_DIR and fall back to the per-user cache. A hard-coded
# developer path used to overwrite this value, which made both unreachable.
MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR)
MODELS = {
    "autoregressive.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth",
    "classifier.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth",
    "clvp2.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth",
    "diffusion_decoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth",
    "vocoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth",
    "rlg_auto.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth",
    "rlg_diffuser.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth",
}


def download_models(specific_models=None):
    """
    Download the Tortoise model files into `MODELS_DIR`.

    Args:
        specific_models: optional collection of file names from `MODELS`; when given,
            only those files are fetched. Files that already exist are skipped.
    """
    os.makedirs(MODELS_DIR, exist_ok=True)
    for model_name, url in MODELS.items():
        if specific_models is not None and model_name not in specific_models:
            continue
        model_path = os.path.join(MODELS_DIR, model_name)
        if os.path.exists(model_path):
            continue
        print(f"Downloading {model_name} from {url}...")
        # Download next to the target and rename on success, so an interrupted
        # transfer never leaves a truncated file that later looks complete.
        partial_path = model_path + ".part"
        try:
            with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
                request.urlretrieve(url, partial_path, lambda nb, bs, fs, t=t: t.update(nb * bs - t.n))
            os.replace(partial_path, model_path)
        finally:
            # After a successful rename the partial file is gone; this only fires on failure.
            if os.path.exists(partial_path):
                os.remove(partial_path)
        print("Done.")


def get_model_path(model_name, models_dir=MODELS_DIR):
    """
    Get path to given model, download it if it doesn't exist.

    The download only happens when `models_dir` is the default `MODELS_DIR`; for a
    custom directory the joined path is returned whether or not the file exists.

    Raises:
        ValueError: if `model_name` is not a key of `MODELS`.
    """
    if model_name not in MODELS:
        raise ValueError(f"Model {model_name} not found in available models.")
    model_path = os.path.join(models_dir, model_name)
    if not os.path.exists(model_path) and models_dir == MODELS_DIR:
        download_models([model_name])
    return model_path
super().__init__() + + self.conv_in_channels = conv_in_channels + self.conv_out_channels = conv_out_channels + self.conv_kernel_size = conv_kernel_size + self.conv_layers = conv_layers + + kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w + kpnet_bias_channels = conv_out_channels * conv_layers # l_b + + self.input_conv = nn.Sequential( + nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + + self.residual_convs = nn.ModuleList() + padding = (kpnet_conv_size - 1) // 2 + for _ in range(3): + self.residual_convs.append( + nn.Sequential( + nn.Dropout(kpnet_dropout), + nn.utils.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_hidden_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + nn.utils.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_hidden_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + ) + self.kernel_conv = nn.utils.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_kernel_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ) + self.bias_conv = nn.utils.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_bias_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ) + + def forward(self, c): + """ + Args: + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + """ + batch, _, cond_length = c.shape + c = self.input_conv(c) + for residual_conv in self.residual_convs: + residual_conv.to(c.device) + c = c + residual_conv(c) + k = self.kernel_conv(c) + b = self.bias_conv(c) + kernels = k.contiguous().view( + batch, + self.conv_layers, + self.conv_in_channels, + self.conv_out_channels, + 
def forward(self, x, c):
    """Forward propagation of the location-variable convolutions.

    The input is first upsampled by `convt_pre` (a transposed convolution with
    the block's stride), then refined by one gated location-variable
    convolution per entry of `conv_blocks`, each added back residually.

    Args:
        x (Tensor): the input sequence (batch, in_channels, in_length)
        c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)

    Returns:
        Tensor: the output sequence (batch, in_channels, stride * in_length)
    """
    _, in_channels, _ = x.shape  # (B, c_g, L')

    x = self.convt_pre(x)  # (B, c_g, stride * L')
    # One set of kernels and biases is predicted per conv block from the conditioning.
    kernels, bias = self.kernel_predictor(c)

    for i, conv in enumerate(self.conv_blocks):
        output = conv(x)  # (B, c_g, stride * L')

        # The kernel predictor lays kernels out as (in, out, ...), with out = 2 * in.
        k = kernels[:, i, :, :, :, :]  # (B, c_g, 2 * c_g, kernel_size, cond_length)
        b = bias[:, i, :, :]  # (B, 2 * c_g, cond_length)

        output = self.location_variable_convolution(
            output, k, b, hop_size=self.cond_hop_length
        )  # (B, 2 * c_g, stride * L'): LVC
        # First half of the channels gates (sigmoid) the second half (tanh).
        x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
            output[:, in_channels:, :]
        )  # (B, c_g, stride * L'): GAU

    return x
class UnivNetGenerator(nn.Module):
    """
    UnivNet Generator

    Originally from https://github.com/mindslab-ai/univnet/blob/master/model/generator.py.

    Turns a noise sequence into a waveform, conditioned on a mel-spectrogram.
    The noise goes through one ``LVCBlock`` per entry of ``strides``, so the
    total upsampling factor is the product of the strides.

    Args:
        noise_dim (int): number of channels of the input noise sequence.
        channel_size (int): number of hidden channels used by every block.
        dilations (Sequence[int]): dilations handed to every ``LVCBlock``.
        strides (Sequence[int]): upsampling stride of each ``LVCBlock``.
        lReLU_slope (float): negative slope of the LeakyReLU activations.
        kpnet_conv_size (int): kernel size forwarded to each ``LVCBlock``.
        hop_length (int): hop length of the mel-spectrogram; used by
            ``inference`` to trim the padded tail.
        n_mel_channels (int): number of mel channels of the conditioning input.
    """

    def __init__(
        self,
        noise_dim=64,
        channel_size=32,
        # Immutable defaults: a list default would be shared by every instance.
        dilations=(1, 3, 9, 27),
        strides=(8, 8, 4),
        lReLU_slope=0.2,
        kpnet_conv_size=3,
        # Below are MEL configurations options that this generator requires.
        hop_length=256,
        n_mel_channels=100,
    ):
        super().__init__()
        self.mel_channel = n_mel_channels
        self.noise_dim = noise_dim
        self.hop_length = hop_length

        self.res_stack = nn.ModuleList()
        # Running product of the strides seen so far; each block receives the
        # cumulative value as its conditioning hop length.
        cond_hop_length = 1
        for stride in strides:
            cond_hop_length = stride * cond_hop_length
            self.res_stack.append(
                LVCBlock(
                    channel_size,
                    n_mel_channels,
                    stride=stride,
                    # Hand every block its own list, as callers passing a list expect.
                    dilations=list(dilations),
                    lReLU_slope=lReLU_slope,
                    cond_hop_length=cond_hop_length,
                    kpnet_conv_size=kpnet_conv_size,
                )
            )

        self.conv_pre = nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect"))

        self.conv_post = nn.Sequential(
            nn.LeakyReLU(lReLU_slope),
            nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
            nn.Tanh(),
        )

    def forward(self, c, z):
        """
        Args:
            c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
            z (Tensor): the noise sequence (batch, noise_dim, in_length)

        Returns:
            Tensor: the output of ``conv_post`` (a single channel, squashed by ``Tanh``).
        """
        z = self.conv_pre(z)  # (B, c_g, L)

        for res_block in self.res_stack:
            # Kept from the original: make sure the block lives on the input's device.
            res_block.to(z.device)
            z = res_block(z, c)  # (B, c_g, L * s_0 * ... * s_i)

        z = self.conv_post(z)  # (B, 1, L * 256)

        return z

    def eval(self, inference=False):
        """Switch to evaluation mode, optionally stripping weight norm.

        Args:
            inference (bool): when True, weight normalization is removed as well.
                Leave it False for validation inside a training loop.

        Returns:
            UnivNetGenerator: ``self``, matching ``nn.Module.eval``.
        """
        super().eval()
        # don't remove weight norm while validation in training loop
        if inference:
            self.remove_weight_norm()
        return self

    def remove_weight_norm(self):
        """Remove weight normalization from every layer that carries it."""
        nn.utils.remove_weight_norm(self.conv_pre)

        for layer in self.conv_post:
            # Activations have an empty state dict; only the conv layer is weight-normed.
            if len(layer.state_dict()) != 0:
                nn.utils.remove_weight_norm(layer)

        for res_block in self.res_stack:
            res_block.remove_weight_norm()

    def inference(self, c, z=None):
        """Generate audio from a mel-spectrogram.

        Args:
            c (Tensor): mel-spectrogram (batch, mel_channels, in_length).
            z (Tensor, optional): noise sequence; sampled from a standard normal
                when omitted.

        Returns:
            Tensor: the generated audio with the padded tail removed, clamped to [-1, 1].
        """
        # pad input mel with zeros to cut artifact
        # see https://github.com/seungwonpark/melgan/issues/8
        # NOTE(review): -11.5129 equals ln(1e-5); presumably the log-mel floor -- confirm.
        zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
        mel = torch.cat((c, zero), dim=2)

        if z is None:
            z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)

        audio = self.forward(mel, z)
        # Drop the samples generated for the 10 padding frames added above.
        audio = audio[:, :, : -(self.hop_length * 10)]
        audio = audio.clamp(min=-1, max=1)
        return audio
def max_alignment(s1, s2, skip_character="~", record=None):
    """
    Align ``s1`` against ``s2`` and return a string as long as ``s1``.

    Characters of ``s1`` that are matched, in order, by characters of ``s2`` are kept;
    every other character of ``s1`` is replaced by ``skip_character``. Among the two
    candidate alignments tried at each step, the one keeping more characters wins.

    Args:
        s1 (str): the string to align; must not contain ``skip_character``.
        s2 (str): the string to align against.
        skip_character (str): placeholder written for unmatched characters of ``s1``.
        record (dict, optional): memo table keyed by the lengths of the remaining
            suffixes; a fresh one is created when omitted.

    Returns:
        str: ``s1`` with its unmatched characters replaced by ``skip_character``.
    """
    memo = {} if record is None else record
    assert skip_character not in s1, f"Found the skip character {skip_character} in the provided string, {s1}"

    # Base cases: nothing left to align, nothing left to align against, or identical tails.
    if len(s1) == 0:
        return ""
    if len(s2) == 0:
        return skip_character * len(s1)
    if s1 == s2:
        return s1

    head1, tail1 = s1[0], s1[1:]
    tail2 = s2[1:]
    if head1 == s2[0]:
        return head1 + max_alignment(tail1, tail2, skip_character, memo)

    def best_for(rest1, rest2):
        # Suffix lengths identify a sub-problem, so they serve as the memo key.
        key = (len(rest1), len(rest2))
        if key not in memo:
            aligned = max_alignment(rest1, rest2, skip_character, memo)
            memo[key] = (aligned, len(aligned.replace(skip_character, "")))
        return memo[key]

    # Option 1: drop the first character of s2 and keep trying to place all of s1.
    keep_s1, keep_s1_score = best_for(s1, tail2)
    # Option 2: give up on the first character of s1 (it becomes a skip).
    drop_s1, drop_s1_score = best_for(tail1, s2)

    if keep_s1_score > drop_s1_score:
        return keep_s1
    return skip_character + drop_s1
+ """ + + def __init__(self, device="cuda"): + self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu() + self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-960h") + self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("jbetker/tacotron-symbols") + self.device = device + + def align(self, audio, expected_text, audio_sample_rate=24000): + orig_len = audio.shape[-1] + + with torch.no_grad(): + self.model = self.model.to(self.device) + audio = audio.to(self.device) + audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000) + clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7) + logits = self.model(clip_norm).logits + self.model = self.model.cpu() + + logits = logits[0] + pred_string = self.tokenizer.decode(logits.argmax(-1).tolist()) + + fixed_expectation = max_alignment(expected_text.lower(), pred_string) + w2v_compression = orig_len // logits.shape[0] + expected_tokens = self.tokenizer.encode(fixed_expectation) + expected_chars = list(fixed_expectation) + if len(expected_tokens) == 1: + return [0] # The alignment is simple; there is only one token. + expected_tokens.pop(0) # The first token is a given. 
def redact(self, audio, expected_text, audio_sample_rate=24000):
    """Cut the audio spoken for ``[bracketed]`` parts of ``expected_text``.

    The text is split on ``[`` / ``]``; the pieces outside brackets are kept and the
    pieces inside brackets are redacted. The audio is aligned against the text with
    the brackets removed, and the kept audio spans are concatenated.

    Args:
        audio (Tensor): audio clip; sliced along its last dimension as ``audio[:, a:b]``.
        expected_text (str): transcript in which redacted spans are wrapped in ``[...]``.
        audio_sample_rate (int): sample rate of ``audio``, forwarded to ``align``.

    Returns:
        Tensor: ``audio`` itself when the text has no ``[``; otherwise the concatenation
        of the kept spans, which is empty along the last dimension when nothing is kept.

    Raises:
        AssertionError: if a ``[`` has no ``]`` after it.
    """
    if "[" not in expected_text:
        return audio
    splitted = expected_text.split("[")
    fully_split = [splitted[0]]
    for spl in splitted[1:]:
        assert "]" in spl, 'Every "[" character must be paired with a "]" with no nesting.'
        fully_split.extend(spl.split("]"))

    # At this point, fully_split is a list of strings, with every other string being something that should be redacted.
    non_redacted_intervals = []
    last_point = 0
    for i, segment in enumerate(fully_split):
        # Skip empty kept segments. A text that starts or ends with a redaction
        # produces one, and the trailing one would index one past the end of the
        # alignment list (IndexError in the previous implementation).
        if i % 2 == 0 and segment:
            non_redacted_intervals.append((last_point, last_point + len(segment) - 1))
        last_point += len(segment)

    if not non_redacted_intervals:
        # Everything was redacted: nothing to align, return an empty clip.
        return audio[:, :0]

    bare_text = "".join(fully_split)
    alignments = self.align(audio, bare_text, audio_sample_rate)

    # Each interval runs from the alignment of its first character to the alignment
    # of its last character (exclusive), as before.
    output_audio = [audio[:, alignments[start] : alignments[stop]] for start, stop in non_redacted_intervals]
    return torch.cat(output_audio, dim=-1)
# keyword argument helpers


def pick_and_pop(keys, d):
    """Remove every key in ``keys`` from ``d`` and return the removed items as a dict."""
    popped = [d.pop(key) for key in keys]
    return dict(zip(keys, popped))


def group_dict_by_key(cond, d):
    """Split ``d`` in two by a predicate on its keys.

    Returns:
        tuple: ``(matching, rest)`` -- the entries whose key satisfies ``cond`` and
        the remaining entries, each as a new dict.
    """
    matching, rest = {}, {}
    for key, value in d.items():
        target = matching if bool(cond(key)) else rest
        target[key] = value
    return matching, rest


def string_begins_with(prefix, str):
    """Return whether ``str`` starts with ``prefix``."""
    return str.startswith(prefix)


def group_by_key_prefix(prefix, d):
    """Split ``d`` into the entries whose key starts with ``prefix`` and all the others."""
    has_prefix = partial(string_begins_with, prefix)
    return group_dict_by_key(has_prefix, d)


def groupby_prefix_and_trim(prefix, d):
    """Like ``group_by_key_prefix``, but with ``prefix`` stripped from the matching keys.

    Returns:
        tuple: ``(trimmed, rest)`` -- the prefixed entries re-keyed without the
        prefix, and the entries that did not carry it.
    """
    prefixed, rest = group_dict_by_key(partial(string_begins_with, prefix), d)
    cut = len(prefix)
    trimmed = {key[cut:]: value for key, value in prefixed.items()}
    return trimmed, rest
class RelativePositionBias(nn.Module):
    """Learned, bucketed relative position bias added to attention logits.

    Every (query, key) offset is mapped to one of ``num_buckets`` buckets; each bucket
    has one learned value per head, which is scaled by ``scale`` and added to the
    attention scores.
    """

    def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
        """Map relative positions (key index minus query index) to bucket ids.

        Small distances get one bucket each; larger ones share logarithmically sized
        buckets, capped at the last bucket. In the non-causal case half of the buckets
        are reserved for negated distances below zero.
        """
        distance = -relative_position
        if causal:
            side_offset = 0
            # Keys ahead of the query collapse to distance 0.
            distance = torch.max(distance, torch.zeros_like(distance))
        else:
            num_buckets //= 2
            side_offset = (distance < 0).long() * num_buckets
            distance = torch.abs(distance)

        max_exact = num_buckets // 2
        is_small = distance < max_exact

        log_ratio = torch.log(distance.float() / max_exact) / math.log(max_distance / max_exact)
        val_if_large = max_exact + (log_ratio * (num_buckets - max_exact)).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        return side_offset + torch.where(is_small, distance, val_if_large)

    def forward(self, qk_dots):
        """Return ``qk_dots`` plus the scaled per-head relative position bias."""
        num_queries, num_keys = qk_dots.shape[-2:]
        device = qk_dots.device
        q_pos = torch.arange(num_queries, dtype=torch.long, device=device)
        k_pos = torch.arange(num_keys, dtype=torch.long, device=device)
        rel_pos = k_pos[None, :] - q_pos[:, None]
        rp_bucket = self._relative_position_bucket(
            rel_pos, causal=self.causal, num_buckets=self.num_buckets, max_distance=self.max_distance
        )
        values = self.relative_attention_bias(rp_bucket)
        bias = rearrange(values, "i j h -> () h i j")
        return qk_dots + (bias * self.scale)
class LearnedAlibiPositionalBias(AlibiPositionalBias):
    """ALiBi positional bias whose per-head slopes are learned in log space.

    Args:
        heads (int): number of heads that receive a slope.
        bidirectional (bool): when True, a second set of slopes is learned for the
            upper triangle of the bias (keys after the query).
    """

    def __init__(self, heads, bidirectional=False):
        super().__init__(heads)
        log_slopes = torch.log(self.slopes)
        self.learned_logslopes = nn.Parameter(log_slopes)

        self.bidirectional = bidirectional
        if self.bidirectional:
            # nn.Parameter wraps the given tensor without copying it. Clone so the
            # past and future slopes do not alias one storage (which would tie them
            # together under in-place optimizer updates and state-dict loading).
            self.learned_logslopes_future = nn.Parameter(log_slopes.clone())

    def forward(self, qk_dots):
        """Add the learned ALiBi bias to attention logits whose last three dims are (heads, i, j)."""
        h, i, j, device = *qk_dots.shape[-3:], qk_dots.device

        def get_slopes(param):
            # Heads beyond the ones with a slope are padded with 0, i.e. get no bias.
            return F.pad(param.exp(), (0, 0, 0, 0, 0, h - param.shape[1]))

        # Reuse the cached offsets only if they cover BOTH the query and the key
        # length; a cache with fewer rows than `i` would be silently broadcast.
        if exists(self.bias) and self.bias.shape[-2] >= i and self.bias.shape[-1] >= j:
            bias = self.bias[..., :i, :j]
        else:
            i_arange = torch.arange(i, device=device)
            j_arange = torch.arange(j, device=device)
            # Signed offset: key index minus query index.
            bias = rearrange(j_arange, "j -> 1 1 1 j") - rearrange(i_arange, "i -> 1 1 i 1")
            self.register_buffer("bias", bias, persistent=False)

        if self.bidirectional:
            past_slopes = get_slopes(self.learned_logslopes)
            future_slopes = get_slopes(self.learned_logslopes_future)
            bias = torch.tril(bias * past_slopes) + torch.triu(bias * future_slopes)
        else:
            slopes = get_slopes(self.learned_logslopes)
            bias = bias * slopes

        return qk_dots + bias
# token shifting


def shift(t, amount, mask=None):
    """Shift ``t`` by ``amount`` positions along its second-to-last dimension.

    Positions vacated by the shift are filled with zeros and positions pushed past
    the end are dropped, so the shape is unchanged. A positive ``amount`` moves
    content towards higher indices, a negative one towards lower indices.

    Args:
        t (Tensor): tensor to shift.
        amount (int): number of positions; ``0`` returns ``t`` untouched.
        mask (Tensor, optional): boolean mask with one entry per position; positions
            where it is False are zeroed before shifting.

    Returns:
        Tensor: the shifted tensor.
    """
    if amount == 0:
        return t

    source = t if mask is None else t.masked_fill(~mask.unsqueeze(-1), 0.0)

    # F.pad takes (left, right) pairs starting from the last dimension: leave the
    # feature dimension alone, pad one side of the sequence and crop the other.
    sequence_padding = (amount, -amount)
    return F.pad(source, (0, 0, *sequence_padding), value=0.0)
class GLU(nn.Module):
    """Gated linear unit: one projection yields a value half and a gate half.

    Args:
        dim_in (int): input feature size.
        dim_out (int): output feature size; the projection produces twice as many.
        activation (Callable): applied to the gate half before it multiplies the value half.
    """

    def __init__(self, dim_in, dim_out, activation):
        super().__init__()
        self.act = activation
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        """Project ``x``, split the last dimension in two and return ``value * act(gate)``."""
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        gate = self.act(gate)
        return value * gate
class Attention(nn.Module):
    """Multi-head attention with a number of optional variants.

    The constructor flags enable collaborative heads, value gating, cosine-similarity
    (qk-norm) attention, talking heads, per-head output scaling, top-k sparsification,
    learned memory key/values, attention-on-attention output and a bucketed relative
    position bias.

    ``forward`` returns a 4-tuple: the projected output, an ``Intermediates`` tuple with
    the pre- and post-softmax attention maps, and the keys and values to cache for
    incremental decoding.
    """

    def __init__(
        self,
        dim,
        dim_head=DEFAULT_DIM_HEAD,
        heads=8,
        causal=False,
        talking_heads=False,
        head_scale=False,
        collab_heads=False,
        collab_compression=0.3,
        sparse_topk=None,
        use_entmax15=False,
        num_mem_kv=0,
        dropout=0.0,
        on_attn=False,
        gate_values=False,
        zero_init_output=False,
        max_attend_past=None,
        qk_norm=False,
        scale_init_value=None,
        rel_pos_bias=False,
        rel_pos_num_buckets=32,
        rel_pos_max_distance=128,
    ):
        """
        Args:
            dim: feature size of the input and of the output.
            dim_head: feature size per head.
            heads: number of attention heads.
            causal: mask out keys that come after the query.
            talking_heads: mix the heads with learned matrices before and after the softmax.
            head_scale: multiply each head's output by a learned scalar.
            collab_heads: share a compressed query/key projection across heads.
            collab_compression: fraction of ``dim_head * heads`` kept for collaborative heads.
            sparse_topk: if set, keep only the top-k logits per query.
            use_entmax15: accepted, but not read anywhere in this class; softmax is always used.
            num_mem_kv: number of learned memory key/value slots prepended to the keys and values.
            dropout: dropout probability applied to the attention weights.
            on_attn: use a ``Linear`` to ``2 * dim`` followed by ``nn.GLU`` as the output projection.
            gate_values: gate the aggregated values with a sigmoid gate computed from ``x``.
            zero_init_output: zero-initialize the output projection via ``init_zero_``.
            max_attend_past: if set, mask keys further in the past than this distance.
            qk_norm: l2-normalize queries and keys and use a learned per-head scale.
            scale_init_value: initial value of the learned qk-norm scale (defaults to -3).
            rel_pos_bias: add a ``RelativePositionBias`` to the logits.
            rel_pos_num_buckets: number of buckets of the relative position bias.
            rel_pos_max_distance: maximum distance of the relative position bias.
        """
        super().__init__()
        self.scale = dim_head**-0.5

        self.heads = heads
        self.causal = causal
        self.max_attend_past = max_attend_past

        qk_dim = v_dim = dim_head * heads

        # collaborative heads
        self.collab_heads = collab_heads
        if self.collab_heads:
            qk_dim = int(collab_compression * qk_dim)
            self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))

        self.to_q = nn.Linear(dim, qk_dim, bias=False)
        self.to_k = nn.Linear(dim, qk_dim, bias=False)
        self.to_v = nn.Linear(dim, v_dim, bias=False)

        self.dropout = nn.Dropout(dropout)

        # add GLU gating for aggregated values, from alphafold2
        self.to_v_gate = None
        if gate_values:
            self.to_v_gate = nn.Linear(dim, v_dim)
            # weight 0 / bias 1: every gate starts at the constant sigmoid(1)
            nn.init.constant_(self.to_v_gate.weight, 0)
            nn.init.constant_(self.to_v_gate.bias, 1)

        # cosine sim attention
        self.qk_norm = qk_norm
        if qk_norm:
            scale_init_value = default(
                scale_init_value, -3
            )  # if not provided, initialize as though it were sequence length of 1024
            # replaces the fixed dim_head**-0.5 scale set above with a learned per-head parameter
            self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value)

        # talking heads
        self.talking_heads = talking_heads
        if talking_heads:
            self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
            self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))

        # head scaling
        self.head_scale = head_scale
        if head_scale:
            self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))

        # explicit topk sparse attention
        self.sparse_topk = sparse_topk

        # entmax
        # NOTE(review): `use_entmax15` is never consulted; the attention function is always softmax.
        self.attn_fn = F.softmax

        # add memory key / values
        self.num_mem_kv = num_mem_kv
        if num_mem_kv > 0:
            self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
            self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))

        # attention on attention
        self.attn_on_attn = on_attn
        self.to_out = nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim)

        self.rel_pos_bias = rel_pos_bias
        if rel_pos_bias:
            # NOTE(review): the check allows equality although the message says "less than".
            assert (
                rel_pos_num_buckets <= rel_pos_max_distance
            ), "number of relative position buckets must be less than the relative position max distance"
            self.rel_pos = RelativePositionBias(
                scale=dim_head**0.5,
                causal=causal,
                heads=heads,
                num_buckets=rel_pos_num_buckets,
                max_distance=rel_pos_max_distance,
            )

        # init output projection 0
        if zero_init_output:
            init_zero_(self.to_out)

    def forward(
        self,
        x,
        context=None,
        mask=None,
        context_mask=None,
        attn_mask=None,
        sinusoidal_emb=None,
        rotary_pos_emb=None,
        prev_attn=None,
        mem=None,
        layer_past=None,
    ):
        """
        Args:
            x: input of shape (batch, seq, dim); it is unpacked as ``b, n, _``.
            context: optional source of keys/values; ``x`` is used when omitted.
            mask: optional boolean mask over the query positions, shape (batch, seq).
            context_mask: optional boolean mask over the context positions.
            attn_mask: optional boolean mask with 2 to 4 dimensions; positions where
                it is False are masked out.
            sinusoidal_emb: optional callable producing position embeddings that are
                added to the query and key inputs.
            rotary_pos_emb: optional rotary frequencies; applied only when no
                ``context`` is given.
            prev_attn: optional logits added to this layer's logits.
            mem: optional memory prepended to the key/value inputs along the sequence.
            layer_past: optional ``(past_key, past_value)`` pair prepended to the keys
                and values along the sequence.

        Returns:
            tuple: ``(out, intermediates, k_cache, v_cache)``.
        """
        b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = (
            *x.shape,
            self.heads,
            self.talking_heads,
            self.collab_heads,
            self.head_scale,
            self.scale,
            x.device,
            exists(context),
        )
        kv_input = default(context, x)

        q_input = x
        k_input = kv_input
        v_input = kv_input

        if exists(mem):
            k_input = torch.cat((mem, k_input), dim=-2)
            v_input = torch.cat((mem, v_input), dim=-2)

        if exists(sinusoidal_emb):
            # in shortformer, the query would start at a position offset depending on the past cached memory
            offset = k_input.shape[-2] - q_input.shape[-2]
            q_input = q_input + sinusoidal_emb(q_input, offset=offset)
            k_input = k_input + sinusoidal_emb(k_input)

        q = self.to_q(q_input)
        k = self.to_k(k_input)
        v = self.to_v(v_input)

        if not collab_heads:
            # split the feature dimension into heads: (b, n, h*d) -> (b, h, n, d)
            q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
        else:
            # collaborative heads: per-head mixing of a shared query, one shared key "head"
            q = einsum("b i d, h d -> b h i d", q, self.collab_mixing)
            k = rearrange(k, "b n d -> b () n d")
            v = rearrange(v, "b n (h d) -> b h n d", h=h)

        if layer_past is not None:
            past_key, past_value = layer_past
            k = torch.cat([past_key, k], dim=-2)
            v = torch.cat([past_value, v], dim=-2)
        # The cache is captured here, i.e. before the rotary embedding is applied and
        # before the learned memory key/values are prepended.
        k_cache = k
        v_cache = v

        if exists(rotary_pos_emb) and not has_context:
            # only the first `l` features of q, k and v are rotated; the rest pass through
            l = rotary_pos_emb.shape[-1]
            (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
            ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
            q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))

        input_mask = None
        if any(map(exists, (mask, context_mask))):
            # build a (b, 1, i, j) mask as the outer product of the query and key masks
            q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
            k_mask = q_mask if not exists(context) else context_mask
            k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
            q_mask = rearrange(q_mask, "b i -> b () i ()")
            k_mask = rearrange(k_mask, "b j -> b () () j")
            input_mask = q_mask * k_mask

        if self.num_mem_kv > 0:
            mem_k, mem_v = map(lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v))
            k = torch.cat((mem_k, k), dim=-2)
            v = torch.cat((mem_v, v), dim=-2)
            if exists(input_mask):
                # the memory slots are always attendable
                input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)

        if collab_heads:
            k = k.expand(-1, h, -1, -1)

        if self.qk_norm:
            q, k = map(l2norm, (q, k))
            # learned scale, stored in log space; the clamp bounds the multiplier at 1 / 1e-2
            scale = 1 / (self.scale.exp().clamp(min=1e-2))

        dots = einsum("b h i d, b h j d -> b h i j", q, k) * scale
        mask_value = max_neg_value(dots)

        if exists(prev_attn):
            dots = dots + prev_attn

        # snapshot of the logits before talking heads, position bias and masking
        pre_softmax_attn = dots.clone()

        if talking_heads:
            dots = einsum("b h i j, h k -> b k i j", dots, self.pre_softmax_proj).contiguous()

        if self.rel_pos_bias:
            dots = self.rel_pos(dots)

        if exists(input_mask):
            dots.masked_fill_(~input_mask, mask_value)
            del input_mask

        if exists(attn_mask):
            # NOTE(review): a 2-dimensional mask is accepted although the message says "greater than 2".
            assert (
                2 <= attn_mask.ndim <= 4
            ), "attention mask must have greater than 2 dimensions but less than or equal to 4"
            if attn_mask.ndim == 2:
                attn_mask = rearrange(attn_mask, "i j -> () () i j")
            elif attn_mask.ndim == 3:
                attn_mask = rearrange(attn_mask, "h i j -> () h i j")
            dots.masked_fill_(~attn_mask, mask_value)

        if exists(self.max_attend_past):
            i, j = dots.shape[-2:]
            # the queries are taken to be the last `i` of the `j` key positions
            range_q = torch.arange(j - i, j, device=device)
            range_k = torch.arange(j, device=device)
            dist = rearrange(range_q, "i -> () () i ()") - rearrange(range_k, "j -> () () () j")
            mask = dist > self.max_attend_past
            dots.masked_fill_(mask, mask_value)
            del mask

        if self.causal:
            i, j = dots.shape[-2:]
            r = torch.arange(i, device=device)
            mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
            # the extra `j - i` leading key positions (cache / memory) stay visible
            mask = F.pad(mask, (j - i, 0), value=False)
            dots.masked_fill_(mask, mask_value)
            del mask

        if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
            top, _ = dots.topk(self.sparse_topk, dim=-1)
            # k-th largest logit per query; everything below it is masked out
            vk = top[..., -1].unsqueeze(-1).expand_as(dots)
            mask = dots < vk
            dots.masked_fill_(mask, mask_value)
            del mask

        attn = self.attn_fn(dots, dim=-1)
        post_softmax_attn = attn.clone()

        attn = self.dropout(attn)

        if talking_heads:
            attn = einsum("b h i j, h k -> b k i j", attn, self.post_softmax_proj).contiguous()

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        if head_scale:
            out = out * self.head_scale_params

        # merge the heads back: (b, h, n, d) -> (b, n, h*d)
        out = rearrange(out, "b h n d -> b n (h d)")

        if exists(self.to_v_gate):
            gates = self.to_v_gate(x)
            out = out * gates.sigmoid()

        intermediates = Intermediates(pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn)

        return self.to_out(out), intermediates, k_cache, v_cache
alibi_num_heads=None, + alibi_learned=False, + position_infused_attn=False, + rotary_pos_emb=False, + rotary_emb_dim=None, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + scale_residual=False, + shift_tokens=0, + sandwich_norm=False, + use_qk_norm_attn=False, + qk_norm_attn_seq_len=None, + zero_init_branch_output=False, + **kwargs, + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs) + attn_kwargs, _ = groupby_prefix_and_trim("attn_", kwargs) + + dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + self.causal = causal + + rel_pos_bias = "rel_pos_bias" in attn_kwargs + self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + + rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32) + self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None + + assert not ( + alibi_pos_bias and rel_pos_bias + ), "you can only choose Alibi positional bias or T5 relative positional bias, not both" + + if alibi_pos_bias: + alibi_num_heads = default(alibi_num_heads, heads) + assert alibi_num_heads <= heads, "number of ALiBi heads must be less than the total number of heads" + alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias + self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, bidirectional=not causal) + else: + self.rel_pos = None + + assert not (not pre_norm and sandwich_norm), "sandwich norm cannot be used when not using prenorm" + self.pre_norm = pre_norm + self.sandwich_norm = sandwich_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + self.cross_attend = cross_attend + + norm_class = ScaleNorm if use_scalenorm 
else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_class = RMSScaleShiftNorm if use_rms_scaleshift_norm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ("a", "c", "f") + elif cross_attend and only_cross: + default_block = ("c", "f") + else: + default_block = ("a", "f") + + if macaron: + default_block = ("f",) + default_block + + # qk normalization + + if use_qk_norm_attn: + attn_scale_init_value = ( + -math.log(math.log2(qk_norm_attn_seq_len**2 - qk_norm_attn_seq_len)) + if exists(qk_norm_attn_seq_len) + else None + ) + attn_kwargs = {**attn_kwargs, "qk_norm": True, "scale_init_value": attn_scale_init_value} + + # zero init + + if zero_init_branch_output: + attn_kwargs = {**attn_kwargs, "zero_init_output": True} + ff_kwargs = {**ff_kwargs, "zero_init_output": True} + + # calculate layer block order + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, "par ratio out of range" + default_block = tuple(filter(not_equals("f"), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, "default block is too large for par_ratio" + par_block = default_block + ("f",) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ("f",) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, "sandwich coefficient should be less than the depth" + layer_types = ("a",) * sandwich_coef + default_block * (depth - sandwich_coef) + ("f",) * sandwich_coef + else: + layer_types = default_block * depth + + 
def forward(
    self,
    x,
    context=None,
    full_context=None,  # for passing a list of hidden states from an encoder
    mask=None,
    context_mask=None,
    attn_mask=None,
    mems=None,
    return_hiddens=False,
    norm_scale_shift_inp=None,
    past_key_values=None,
    expected_seq_len=None,
):
    """Run ``x`` through every layer listed in ``self.layer_types``.

    Args:
        x: Input activations; the sequence length is read from ``x.shape[1]``.
        context: Single context handed to every cross-attention ("c") layer.
        full_context: Indexable of contexts, one per cross-attention layer, read in
            layer order. Mutually exclusive with ``context``.
        mask: Forwarded to every attention block.
        context_mask: Forwarded to cross-attention blocks only.
        attn_mask: Forwarded to self-attention ("a") blocks only.
        mems: Optional list with one entry per self-attention layer. It is copied,
            so the caller's list is left untouched.
        return_hiddens: When True, also return a ``LayerIntermediates`` record.
        norm_scale_shift_inp: When given, passed as a keyword argument to every norm call.
        past_key_values: Optional sequence holding one cached entry per attention
            layer (self and cross), in layer order. Each entry is an iterable of
            tensors and is moved to ``x.device`` before use. The sequence itself is
            not modified.
        expected_seq_len: Lower bound for the rotary embedding length. Required when
            rotary embeddings are used on a causal stack outside training mode.

    Returns:
        ``x`` after the last layer, or ``(x, LayerIntermediates)`` when
        ``return_hiddens`` is True.
    """
    assert not (
        self.cross_attend ^ (exists(context) or exists(full_context))
    ), "context must be passed in if cross_attend is set to True"
    assert context is None or full_context is None, "only one of full_context or context can be provided"

    hiddens = []
    intermediates = []
    prev_attn = None
    prev_cross_attn = None

    mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
    # Work on a private copy so consuming the cache layer by layer does not empty
    # the caller's container (this also lets a tuple be passed).
    remaining_past = list(past_key_values) if past_key_values is not None else None

    norm_args = {}
    if exists(norm_scale_shift_inp):
        norm_args["norm_scale_shift_inp"] = norm_scale_shift_inp

    rotary_pos_emb = None
    if exists(self.rotary_pos_emb):
        if not self.training and self.causal:
            assert (
                expected_seq_len is not None
            ), "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
        elif expected_seq_len is None:
            expected_seq_len = 0
        seq_len = x.shape[1]
        if past_key_values is not None:
            # Cached keys extend the positions the rotary table has to cover.
            seq_len += past_key_values[0][0].shape[-2]
        max_rotary_emb_length = max(
            [(m.shape[1] if exists(m) else 0) + seq_len for m in mems] + [expected_seq_len]
        )
        rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)

    present_key_values = []
    cross_attn_count = 0
    for layer_type, (norm, block, residual_fn) in zip(self.layer_types, self.layers):
        is_attention = layer_type in ("a", "c")

        if layer_type == "a":
            layer_mem = mems.pop(0) if mems else None

        residual = x

        pre_branch_norm, post_branch_norm, post_main_norm = norm

        if exists(pre_branch_norm):
            x = pre_branch_norm(x, **norm_args)

        if is_attention:
            if remaining_past is not None:
                layer_kv = remaining_past.pop(0)
                layer_past = tuple(s.to(x.device) for s in layer_kv)
            else:
                layer_past = None

        if layer_type == "a":
            out, inter, k, v = block(
                x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, prev_attn, layer_mem, layer_past
            )
        elif layer_type == "c":
            layer_context = full_context[cross_attn_count] if exists(full_context) else context
            # Residual attention for cross layers must chain cross-attention scores;
            # the self-attention scores in `prev_attn` belong to a different
            # key/value sequence.
            out, inter, k, v = block(
                x, layer_context, mask, context_mask, None, None, None, prev_cross_attn, None, layer_past
            )
        elif layer_type == "f":
            out = block(x)

        if is_attention:
            present_key_values.append((k.detach(), v.detach()))

        if exists(post_branch_norm):
            out = post_branch_norm(out, **norm_args)

        x = residual_fn(out, residual)

        if is_attention:
            intermediates.append(inter)

        if layer_type == "a" and self.residual_attn:
            prev_attn = inter.pre_softmax_attn
        elif layer_type == "c" and self.cross_residual_attn:
            prev_cross_attn = inter.pre_softmax_attn

        if exists(post_main_norm):
            x = post_main_norm(x, **norm_args)

        if layer_type == "c":
            cross_attn_count += 1

        if layer_type == "f":
            hiddens.append(x)

    if return_hiddens:
        intermediates = LayerIntermediates(
            hiddens=hiddens, attn_intermediates=intermediates, past_key_values=present_key_values
        )

        return x, intermediates

    return x
class TransformerWrapper(nn.Module):
    """Token-level wrapper around an ``AttentionLayers`` stack.

    Embeds integer tokens, optionally adds a positional embedding, runs the
    attention layers, applies a final LayerNorm and projects to logits over
    ``num_tokens``.
    """

    def __init__(
        self,
        *,
        num_tokens,
        max_seq_len,
        attn_layers,
        emb_dim=None,
        max_mem_len=0.0,
        shift_mem_down=0,
        emb_dropout=0.0,
        num_memory_tokens=None,
        tie_embedding=False,
        use_pos_emb=True,
    ):
        """Build the embedding, projection and output layers.

        Args:
            num_tokens: Vocabulary size for the token embedding and the logits.
            max_seq_len: Stored on the instance and passed to the positional embedding.
            attn_layers: An ``AttentionLayers`` instance; its ``dim`` sets the model width.
            emb_dim: Embedding width; defaults to ``attn_layers.dim``.
            max_mem_len: Stored on the instance; not otherwise used in this class.
            shift_mem_down: Number of leading ``mems`` entries rotated to the end in ``forward``.
            emb_dropout: Dropout probability applied to the embeddings.
            num_memory_tokens: Number of learned memory tokens prepended to each sequence.
            tie_embedding: When True, logits reuse the transposed token-embedding matrix.
            use_pos_emb: When False, no absolute positional embedding is added.
        """
        super().__init__()
        assert isinstance(attn_layers, AttentionLayers), "attention layers must be one of Encoder or Decoder"

        dim = attn_layers.dim
        emb_dim = default(emb_dim, dim)

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len
        self.shift_mem_down = shift_mem_down

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        # Skip the absolute embedding when the attention stack reports its own
        # positional scheme; `always(0)` is presumably a callable returning 0 — TODO confirm.
        self.pos_emb = (
            AbsolutePositionalEmbedding(emb_dim, max_seq_len)
            if (use_pos_emb and not attn_layers.has_pos_emb)
            else always(0)
        )
        self.emb_dropout = nn.Dropout(emb_dropout)

        # Only project when the embedding width differs from the model width.
        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)

        self.init_()

        # With tied embeddings the logits are a matmul against the embedding weights.
        self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()

        # memory tokens (like [cls]) from Memory Transformers paper
        num_memory_tokens = default(num_memory_tokens, 0)
        self.num_memory_tokens = num_memory_tokens
        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

    def init_(self):
        """Re-initialise the token embedding weights with Kaiming-normal values."""
        nn.init.kaiming_normal_(self.token_emb.weight)

    def forward(
        self,
        x,
        return_embeddings=False,
        mask=None,
        return_hiddens=False,
        return_attn=False,
        mems=None,
        use_cache=False,
        **kwargs,
    ):
        """Embed ``x``, run the attention stack and return logits or embeddings.

        Args:
            x: Token ids; must be 2-D, since its shape is unpacked into exactly two sizes.
            return_embeddings: Return the normalised hidden states instead of logits.
            mask: Mask forwarded to the attention layers; left-padded with True
                to cover memory tokens when those are present.
            return_hiddens: Return ``(out, hiddens)`` and ignore ``return_attn`` / ``use_cache``.
            return_attn: Also return the post-softmax attention maps.
            mems: Optional memories forwarded to the attention layers.
            use_cache: Also return the key/value cache produced by the attention layers.
            **kwargs: Forwarded unchanged to ``self.attn_layers``.

        Returns:
            ``out`` alone, ``(out, hiddens)``, or a tuple of ``out`` followed by the
            attention maps and/or key/value cache that were requested.
        """
        b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
        x = self.token_emb(x)
        x = x + self.pos_emb(x)
        x = self.emb_dropout(x)

        x = self.project_emb(x)

        if num_mem > 0:
            # Prepend the same learned memory tokens to every sequence in the batch.
            mem = repeat(self.memory_tokens, "n d -> b n d", b=b)
            x = torch.cat((mem, x), dim=1)

            # auto-handle masking after appending memory tokens
            if exists(mask):
                mask = F.pad(mask, (num_mem, 0), value=True)

        if self.shift_mem_down and exists(mems):
            # Rotate the list so the first `shift_mem_down` memories move to the end.
            mems_l, mems_r = mems[: self.shift_mem_down], mems[self.shift_mem_down :]
            mems = [*mems_r, *mems_l]

        x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
        x = self.norm(x)

        # Drop the memory-token positions again before producing the output.
        mem, x = x[:, :num_mem], x[:, num_mem:]

        out = self.to_logits(x) if not return_embeddings else x

        if return_hiddens:
            hiddens = intermediates.hiddens
            return out, hiddens

        res = [out]
        if return_attn:
            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
            res.append(attn_maps)
        if use_cache:
            res.append(intermediates.past_key_values)

        if len(res) > 1:
            return tuple(res)
        return res[0]
(use_pos_emb and not attn_layers.has_pos_emb) + else always(0) + ) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity() + + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity() + + def forward(self, x, return_embeddings=False, mask=None, return_attn=False, mems=None, use_cache=False, **kwargs): + b, n, _, device = *x.shape, x.device + + x = self.project_in(x) + x = x + self.pos_emb(x) + x = self.emb_dropout(x) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + out = self.project_out(x) if not return_embeddings else x + + res = [out] + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + res.append(attn_maps) + if use_cache: + res.append(intermediates.past_key_values) + + if len(res) > 1: + return tuple(res) + return res[0] diff --git a/TTS/tts/models/__init__.py b/TTS/tts/models/__init__.py index d76a3bebee..2bd2e5f087 100644 --- a/TTS/tts/models/__init__.py +++ b/TTS/tts/models/__init__.py @@ -10,5 +10,5 @@ def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) MyModel = find_module("TTS.tts.models", config.base_model.lower()) else: MyModel = find_module("TTS.tts.models", config.model.lower()) - model = MyModel.init_from_config(config, samples) + model = MyModel.init_from_config(config=config, samples=samples) return model diff --git a/TTS/tts/models/tortoise.py b/TTS/tts/models/tortoise.py new file mode 100644 index 0000000000..4d558a122c --- /dev/null +++ b/TTS/tts/models/tortoise.py @@ -0,0 +1,900 @@ +import os +import random +from contextlib import contextmanager +from dataclasses import dataclass +from time import time + +import torch +import torch.nn.functional as F +import torchaudio +from coqpit import Coqpit +from tqdm import tqdm + +from 
def pad_or_truncate(t, length):
    """
    Force the last dimension of `t` to `length`: longer inputs are clipped, shorter ones are
    right-padded with 0s, and an input that already has that length is returned unchanged.
    """
    current = t.shape[-1]
    if current == length:
        return t
    if current < length:
        return F.pad(t, (0, length - current))
    return t[..., :length]


def deterministic_state(seed=None):
    """
    Seeds torch and Python's `random` with `seed`, falling back to the current time() when no seed
    is given, and returns the seed that was used so results can be reproduced.
    """
    if seed is None:
        seed = int(time())
    torch.manual_seed(seed)
    random.seed(seed)
    # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
    # torch.use_deterministic_algorithms(True)

    return seed
def fix_autoregressive_output(codes, stop_token, complain=True):
    """
    Rewrites, in place, the tail of a generated code sequence so it looks like what the diffusion model
    was trained on. The autoregressive generator emits no padding or end marker, so everything from the
    first stop token onward is replaced with padding codes and a fixed three-code ending.

    The literal code values are specific to the DVAE in use and will not necessarily work with a
    different one. They can be recovered by pushing a clip padded with lots of zeros through the DVAE
    and copying out the last few codes.

    Skipping this step produces speech with a harsh end that sounds like "BLAH" or similar.

    Returns the same `codes` tensor that was passed in.
    """
    stop_positions = (codes == stop_token).nonzero()

    if len(stop_positions) != 0:
        # Strip off the autoregressive stop token and add padding.
        codes[stop_positions] = 83
        first_stop = stop_positions.min().item()
        codes[first_stop:] = 83
        # NOTE(review): first_stop is an index into codes, so this guard always holds.
        if first_stop - 3 < codes.shape[0]:
            for position, value in zip((-3, -2, -1), (45, 45, 248)):
                codes[position] = value
        return codes

    if complain:
        print(
            "No stop tokens found in one of the generated voice clips. This typically means the spoken audio is "
            "too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, "
            "try breaking up your input text."
        )
    return codes
def pick_best_batch_size_for_gpu():
    """
    Tries to pick a batch size that will fit in your GPU. These sizes aren't guaranteed to work, but they
    should give you a good shot.

    Returns:
        int: 16, 8 or 4 depending on the memory figure reported by CUDA, otherwise 1. Hosts without
        CUDA also get 1; previously the name was unbound on that path and the call raised.
    """
    batch_size = 1
    if torch.cuda.is_available():
        # NOTE(review): torch.cuda.mem_get_info() is documented to return (free, total), so the
        # second element read here is the device's total memory - confirm that is intended.
        _, available = torch.cuda.mem_get_info()
        available_gb = available / (1024**3)
        if available_gb > 14:
            batch_size = 16
        elif available_gb > 10:
            batch_size = 8
        elif available_gb > 7:
            batch_size = 4
    return batch_size
+ vocoder (VocType, optional): The vocoder to use for synthesis. Defaults to VocConf.Univnet. + + For UnifiedVoice model: + ar_max_mel_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604. + ar_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402. + ar_max_conditioning_inputs (int, optional): The maximum conditioning inputs for the autoregressive model. Defaults to 2. + ar_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30. + ar_model_dim (int, optional): The model dimension for the autoregressive model. Defaults to 1024. + ar_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16. + ar_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255. + ar_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255. + ar_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False. + ar_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False. + + For DiffTTS model: + diff_model_channels (int, optional): The number of channels for the DiffTTS model. Defaults to 1024. + diff_num_layers (int, optional): The number of layers for the DiffTTS model. Defaults to 10. + diff_in_channels (int, optional): The input channels for the DiffTTS model. Defaults to 100. + diff_out_channels (int, optional): The output channels for the DiffTTS model. Defaults to 200. + diff_in_latent_channels (int, optional): The input latent channels for the DiffTTS model. Defaults to 1024. + diff_in_tokens (int, optional): The input tokens for the DiffTTS model. Defaults to 8193. + diff_dropout (int, optional): The dropout percentage for the DiffTTS model. Defaults to 0. + diff_use_fp16 (bool, optional): Whether to use fp16 for the DiffTTS model. 
Defaults to False. + diff_num_heads (int, optional): The number of heads for the DiffTTS model. Defaults to 16. + diff_layer_drop (int, optional): The layer dropout percentage for the DiffTTS model. Defaults to 0. + diff_unconditioned_percentage (int, optional): The percentage of unconditioned inputs for the DiffTTS model. Defaults to 0. + + For ConditionalLatentVariablePerseq model: + clvp_dim_text (int): The dimension of the text input for the CLVP module. Defaults to 768. + clvp_dim_speech (int): The dimension of the speech input for the CLVP module. Defaults to 768. + clvp_dim_latent (int): The dimension of the latent representation for the CLVP module. Defaults to 768. + clvp_num_text_tokens (int): The number of text tokens used by the CLVP module. Defaults to 256. + clvp_text_enc_depth (int): The depth of the text encoder in the CLVP module. Defaults to 20. + clvp_text_seq_len (int): The maximum sequence length of the text input for the CLVP module. Defaults to 350. + clvp_text_heads (int): The number of attention heads used by the text encoder in the CLVP module. Defaults to 12. + clvp_num_speech_tokens (int): The number of speech tokens used by the CLVP module. Defaults to 8192. + clvp_speech_enc_depth (int): The depth of the speech encoder in the CLVP module. Defaults to 20. + clvp_speech_heads (int): The number of attention heads used by the speech encoder in the CLVP module. Defaults to 12. + clvp_speech_seq_len (int): The maximum sequence length of the speech input for the CLVP module. Defaults to 430. + clvp_use_xformers (bool): A flag indicating whether the model uses transformers in the CLVP module. Defaults to True. + duration_const (int): A constant value used in the model. Defaults to 102400. 
+ """ + + autoregressive_batch_size: int = 1 + enable_redaction: bool = True + high_vram: bool = False + kv_cache: bool = True + ar_checkpoint: str = None + clvp_checkpoint: str = None + diff_checkpoint: str = None + num_chars: int = 255 + vocoder: VocType = VocConf.Univnet + + # UnifiedVoice params + ar_max_mel_tokens: int = 604 + ar_max_text_tokens: int = 402 + ar_max_conditioning_inputs: int = 2 + ar_layers: int = 30 + ar_model_dim: int = 1024 + ar_heads: int = 16 + ar_number_text_tokens: int = 255 + ar_start_text_token: int = 255 + ar_checkpointing: bool = False + ar_train_solo_embeddings: bool = False + + # DiffTTS params + diff_model_channels: int = 1024 + diff_num_layers: int = 10 + diff_in_channels: int = 100 + diff_out_channels: int = 200 + diff_in_latent_channels: int = 1024 + diff_in_tokens: int = 8193 + diff_dropout: int = 0 + diff_use_fp16: bool = False + diff_num_heads: int = 16 + diff_layer_drop: int = 0 + diff_unconditioned_percentage: int = 0 + + # clvp params + clvp_dim_text: int = 768 + clvp_dim_speech: int = 768 + clvp_dim_latent: int = 768 + clvp_num_text_tokens: int = 256 + clvp_text_enc_depth: int = 20 + clvp_text_seq_len: int = 350 + clvp_text_heads: int = 12 + clvp_num_speech_tokens: int = 8192 + clvp_speech_enc_depth: int = 20 + clvp_speech_heads: int = 12 + clvp_speech_seq_len: int = 430 + clvp_use_xformers: bool = True + # constants + duration_const: int = 102400 + + +class Tortoise(BaseTTS): + """Tortoise model class. + + Currently only supports inference. 
+ + Examples: + >>> from TTS.tts.configs.tortoise_config import TortoiseConfig + >>> from TTS.tts.models.tortoise import Tortoise + >>> config = TortoiseConfig() + >>> model = Tortoise.inif_from_config(config) + >>> model.load_checkpoint(config, checkpoint_dir="paths/to/models_dir/", eval=True) + """ + + def __init__(self, config: Coqpit): + super().__init__(config, ap=None, tokenizer=None) + self.config = config + self.ar_checkpoint = self.args.ar_checkpoint + self.diff_checkpoint = self.args.diff_checkpoint # TODO: check if this is even needed + self.models_dir = config.model_dir + self.autoregressive_batch_size = ( + pick_best_batch_size_for_gpu() + if self.args.autoregressive_batch_size is None + else self.args.autoregressive_batch_size + ) + self.enable_redaction = self.args.enable_redaction + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if self.enable_redaction: + self.aligner = Wav2VecAlignment() + + self.tokenizer = VoiceBpeTokenizer() + + self.autoregressive = UnifiedVoice( + max_mel_tokens=self.args.ar_max_mel_tokens, + max_text_tokens=self.args.ar_max_text_tokens, + max_conditioning_inputs=self.args.ar_max_conditioning_inputs, + layers=self.args.ar_layers, + model_dim=self.args.ar_model_dim, + heads=self.args.ar_heads, + number_text_tokens=self.args.ar_number_text_tokens, + start_text_token=self.args.ar_start_text_token, + checkpointing=self.args.ar_checkpointing, + train_solo_embeddings=self.args.ar_train_solo_embeddings, + ).cpu() + + self.diffusion = DiffusionTts( + model_channels=self.args.diff_model_channels, + num_layers=self.args.diff_num_layers, + in_channels=self.args.diff_in_channels, + out_channels=self.args.diff_out_channels, + in_latent_channels=self.args.diff_in_latent_channels, + in_tokens=self.args.diff_in_tokens, + dropout=self.args.diff_dropout, + use_fp16=self.args.diff_use_fp16, + num_heads=self.args.diff_num_heads, + layer_drop=self.args.diff_layer_drop, + 
@contextmanager
def temporary_cuda(self, model):
    """Yield `model` placed on `self.device`, moving it back to the CPU afterwards.

    When `self.high_vram` is set the model is yielded as-is and never moved.
    Otherwise it is moved with `model.to(self.device)` for the duration of the
    `with` block and returned to the CPU on exit, including when the block raises,
    so a failed inference step does not leave the model on the accelerator.
    """
    if self.high_vram:
        yield model
        return

    moved = model.to(self.device)
    try:
        yield moved
    finally:
        model.cpu()
+ :param latent_averaging_mode: 0/1/2 for following modes: + 0 - latents will be generated as in original tortoise, using ~4.27s from each voice sample, averaging latent across all samples + 1 - latents will be generated using (almost) entire voice samples, averaged across all the ~4.27s chunks + 2 - latents will be generated using (almost) entire voice samples, averaged per voice sample + """ + assert latent_averaging_mode in [ + 0, + 1, + 2, + ], "latent_averaging mode has to be one of (0, 1, 2)" + + with torch.no_grad(): + voice_samples = [[v.to(self.device) for v in ls] for ls in voice_samples] + + auto_conds = [] + for ls in voice_samples: + auto_conds.append(format_conditioning(ls[0], device=self.device)) + auto_conds = torch.stack(auto_conds, dim=1) + with self.temporary_cuda(self.autoregressive) as ar: + auto_latent = ar.get_conditioning(auto_conds) + + diffusion_conds = [] + + DURS_CONST = self.args.duration_const + for ls in voice_samples: + # The diffuser operates at a sample rate of 24000 (except for the latent inputs) + sample = torchaudio.functional.resample(ls[0], 22050, 24000) if original_tortoise else ls[1] + if latent_averaging_mode == 0: + sample = pad_or_truncate(sample, DURS_CONST) + cond_mel = wav_to_univnet_mel( + sample.to(self.device), + do_normalization=False, + device=self.device, + ) + diffusion_conds.append(cond_mel) + else: + from math import ceil + + if latent_averaging_mode == 2: + temp_diffusion_conds = [] + for chunk in range(ceil(sample.shape[1] / DURS_CONST)): + current_sample = sample[:, chunk * DURS_CONST : (chunk + 1) * DURS_CONST] + current_sample = pad_or_truncate(current_sample, DURS_CONST) + cond_mel = wav_to_univnet_mel( + current_sample.to(self.device), + do_normalization=False, + device=self.device, + ) + if latent_averaging_mode == 1: + diffusion_conds.append(cond_mel) + elif latent_averaging_mode == 2: + temp_diffusion_conds.append(cond_mel) + if latent_averaging_mode == 2: + 
def get_random_conditioning_latents(self):
    """Return a pair of conditioning latents produced by the random latent generators.

    The two `RandomLatentConverter` models are created on first use from
    `rlg_auto.pth` and `rlg_diffuser.pth` inside `self.models_dir`, then cached on
    the instance.

    Returns:
        Tuple of `(self.rlg_auto(...), self.rlg_diffusion(...)`), each evaluated on
        `torch.tensor([0.0])` under `torch.no_grad()`.
    """
    # Lazy-load the RLG models. Both are built before either is stored, so a failed
    # load cannot leave one attribute set and the other None for later calls.
    if self.rlg_auto is None or self.rlg_diffusion is None:
        cpu = torch.device("cpu")

        rlg_auto = RandomLatentConverter(1024).eval()
        rlg_auto.load_state_dict(torch.load(os.path.join(self.models_dir, "rlg_auto.pth"), map_location=cpu))

        rlg_diffusion = RandomLatentConverter(2048).eval()
        rlg_diffusion.load_state_dict(
            torch.load(os.path.join(self.models_dir, "rlg_diffuser.pth"), map_location=cpu)
        )

        self.rlg_auto = rlg_auto
        self.rlg_diffusion = rlg_diffusion

    with torch.no_grad():
        return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
def inference_with_config(self, text, config, **kwargs):
    """Run `inference()` with settings assembled from a config, a preset and overrides.

    Settings are layered in this order, later layers winning:
      1. sampling values read from `config` (temperature, length_penalty,
         repetition_penalty, top_p, cond_free_k, diffusion_temperature, sampler);
      2. the named preset, if `preset=` is passed and is not None;
      3. every remaining keyword argument.

    Args:
        text (str): Input text, passed through to `inference()`.
        config: Object exposing the attributes listed in (1).
        **kwargs: Optional `preset` name plus any `inference()` keyword arguments.

    Raises:
        KeyError: If `preset` is not one of the known preset names.

    Returns:
        Whatever `self.inference(text, **settings)` returns.
    """
    # Use generally found best tuning knobs for generation.
    settings = {
        "temperature": config.temperature,
        "length_penalty": config.length_penalty,
        "repetition_penalty": config.repetition_penalty,
        "top_p": config.top_p,
        "cond_free_k": config.cond_free_k,
        "diffusion_temperature": config.diffusion_temperature,
        "sampler": config.sampler,
    }
    # Presets are defined here.
    presets = {
        "single_sample": {
            "num_autoregressive_samples": 8,
            "diffusion_iterations": 10,
            "sampler": "ddim",
        },
        "ultra_fast": {
            "num_autoregressive_samples": 16,
            "diffusion_iterations": 10,
            "sampler": "ddim",
        },
        "ultra_fast_old": {
            "num_autoregressive_samples": 16,
            "diffusion_iterations": 30,
            "cond_free": False,
        },
        "very_fast": {
            "num_autoregressive_samples": 32,
            "diffusion_iterations": 30,
            "sampler": "dpm++2m",
        },
        "fast": {
            "num_autoregressive_samples": 5,
            "diffusion_iterations": 50,
            "sampler": "ddim",
        },
        "fast_old": {"num_autoregressive_samples": 96, "diffusion_iterations": 80},
        "standard": {
            "num_autoregressive_samples": 5,
            "diffusion_iterations": 200,
        },
        "high_quality": {
            "num_autoregressive_samples": 256,
            "diffusion_iterations": 400,
        },
    }

    # `preset=None` means "no preset": wrappers that always forward the argument
    # should not fail on the lookup.
    preset = kwargs.pop("preset", None)
    if preset is not None:
        if preset not in presets:
            raise KeyError(f"Unknown preset {preset!r}; expected one of {sorted(presets)}")
        settings.update(presets[preset])

    settings.update(kwargs)  # allow overriding of preset settings with kwargs
    return self.inference(text, **settings)
+ conditioning_latents: (Tuple[autoregressive_conditioning_latent, diffusion_conditioning_latent]) A tuple of + (autoregressive_conditioning_latent, diffusion_conditioning_latent), which can be provided in lieu + of voice_samples. This is ignored unless `voice_samples=None`. Conditioning latents can be retrieved + via `get_conditioning_latents()`. + k: (int) The number of returned clips. The most likely (as determined by Tortoises' CLVP model) clips are returned. + latent_averaging_mode: (int) 0/1/2 for following modes: + 0 - latents will be generated as in original tortoise, using ~4.27s from each voice sample, averaging latent across all samples + 1 - latents will be generated using (almost) entire voice samples, averaged across all the ~4.27s chunks + 2 - latents will be generated using (almost) entire voice samples, averaged per voice sample + verbose: (bool) Whether or not to print log messages indicating the progress of creating a clip. Default=true. + num_autoregressive_samples: (int) Number of samples taken from the autoregressive model, all of which are filtered using CLVP. + As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great". + temperature: (float) The softmax temperature of the autoregressive model. + length_penalty: (float) A length penalty applied to the autoregressive decoder. Higher settings causes the model to produce more terse outputs. + repetition_penalty: (float) A penalty that prevents the autoregressive decoder from repeating itself during decoding. Can be used to reduce + the incidence of long silences or "uhhhhhhs", etc. + top_p: (float) P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" (aka boring) outputs. + max_mel_tokens: (int) Restricts the output length. (0,600] integer. Each unit is 1/20 of a second. + typical_sampling: (bool) Turns typical sampling on or off. 
This sampling mode is discussed in this paper: https://arxiv.org/abs/2202.00666 + I was interested in the premise, but the results were not as good as I was hoping. This is off by default, but could use some tuning. + typical_mass: (float) The typical_mass parameter from the typical_sampling algorithm. + diffusion_iterations: (int) Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively + refine the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better, however. + cond_free: (bool) Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion performs two forward passes for + each diffusion step: one with the outputs of the autoregressive model and one with no conditioning priors. The output of the two + is blended according to the cond_free_k value below. Conditioning-free diffusion is the real deal, and dramatically improves realism. + cond_free_k: (float) Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf]. + As cond_free_k increases, the output becomes dominated by the conditioning-free signal. + diffusion_temperature: (float) Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0 + are the "mean" prediction of the diffusion network and will sound bland and smeared. + hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive transformer. + Extra keyword args fed to this function get forwarded directly to that API. Documentation + here: https://huggingface.co/docs/transformers/internal/generation_utils + + Returns: + Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. + Sample rate is 24kHz. 
+ """ + deterministic_seed = deterministic_state(seed=use_deterministic_seed) + + text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device) + text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. + assert ( + text_tokens.shape[-1] < 400 + ), "Too much text provided. Break the text up into separate segments and re-try inference." + + if voice_samples is not None: + ( + auto_conditioning, + diffusion_conditioning, + _, + _, + ) = self.get_conditioning_latents( + voice_samples, + return_mels=True, + latent_averaging_mode=latent_averaging_mode, + original_tortoise=original_tortoise, + ) + elif conditioning_latents is not None: + auto_conditioning, diffusion_conditioning = conditioning_latents + else: + ( + auto_conditioning, + diffusion_conditioning, + ) = self.get_random_conditioning_latents() + auto_conditioning = auto_conditioning.to(self.device) + diffusion_conditioning = diffusion_conditioning.to(self.device) + + diffuser = load_discrete_vocoder_diffuser( + desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k, sampler=sampler + ) + + # in the case of single_sample, + orig_batch_size = self.autoregressive_batch_size + while num_autoregressive_samples % self.autoregressive_batch_size: + self.autoregressive_batch_size //= 2 + with torch.no_grad(): + samples = [] + num_batches = num_autoregressive_samples // self.autoregressive_batch_size + stop_mel_token = self.autoregressive.stop_mel_token + calm_token = ( + 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" + ) + self.autoregressive = self.autoregressive.to(self.device) + if verbose: + print("Generating autoregressive samples..") + with self.temporary_cuda(self.autoregressive) as autoregressive, torch.autocast( + device_type="cuda", dtype=torch.float16, enabled=half + ): + for b in tqdm(range(num_batches), disable=not verbose): + codes = autoregressive.inference_speech( + 
auto_conditioning, + text_tokens, + do_sample=True, + top_p=top_p, + temperature=temperature, + num_return_sequences=self.autoregressive_batch_size, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + max_generate_length=max_mel_tokens, + **hf_generate_kwargs, + ) + padding_needed = max_mel_tokens - codes.shape[1] + codes = F.pad(codes, (0, padding_needed), value=stop_mel_token) + samples.append(codes) + self.autoregressive_batch_size = orig_batch_size # in the case of single_sample + + clip_results = [] + with self.temporary_cuda(self.clvp) as clvp, torch.autocast( + device_type="cuda", dtype=torch.float16, enabled=half + ): + for batch in tqdm(samples, disable=not verbose): + for i in range(batch.shape[0]): + batch[i] = fix_autoregressive_output(batch[i], stop_mel_token) + clvp_res = clvp( + text_tokens.repeat(batch.shape[0], 1), + batch, + return_loss=False, + ) + clip_results.append(clvp_res) + + clip_results = torch.cat(clip_results, dim=0) + samples = torch.cat(samples, dim=0) + best_results = samples[torch.topk(clip_results, k=k).indices] + del samples + + # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning + # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these + # results, but will increase memory usage. 
+ with self.temporary_cuda(self.autoregressive) as autoregressive: + best_latents = autoregressive( + auto_conditioning.repeat(k, 1), + text_tokens.repeat(k, 1), + torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), + best_results, + torch.tensor( + [best_results.shape[-1] * self.autoregressive.mel_length_compression], + device=text_tokens.device, + ), + return_latent=True, + clip_inputs=False, + ) + del auto_conditioning + + if verbose: + print("Transforming autoregressive outputs into audio..") + wav_candidates = [] + for b in range(best_results.shape[0]): + codes = best_results[b].unsqueeze(0) + latents = best_latents[b].unsqueeze(0) + + # Find the first occurrence of the "calm" token and trim the codes to that. + ctokens = 0 + for code in range(codes.shape[-1]): + if codes[0, code] == calm_token: + ctokens += 1 + else: + ctokens = 0 + if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech. + latents = latents[:, :code] + break + with self.temporary_cuda(self.diffusion) as diffusion: + mel = do_spectrogram_diffusion( + diffusion, + diffuser, + latents, + diffusion_conditioning, + temperature=diffusion_temperature, + verbose=verbose, + ) + with self.temporary_cuda(self.vocoder) as vocoder: + wav = vocoder.inference(mel) + wav_candidates.append(wav.cpu()) + + def potentially_redact(clip, text): + if self.enable_redaction: + return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1) + return clip + + wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates] + + if len(wav_candidates) > 1: + res = wav_candidates + else: + res = wav_candidates[0] + + return_dict = { + "wav": res, + "deterministic_seed": None, + "text": None, + "voice_samples": None, + "conditioning_latents": None, + } + if return_deterministic_state: + return_dict = { + "wav": res, + "deterministic_seed": deterministic_seed, + "text": text, + "voice_samples": voice_samples, + "conditioning_latents": 
conditioning_latents, + } + return return_dict + + def forward(self): + raise NotImplementedError("Tortoise Training is not implemented") + + def eval_step(self): + raise NotImplementedError("Tortoise Training is not implemented") + + @staticmethod + def init_from_config(config: "TortoiseConfig", **kwargs): # pylint: disable=unused-argument + return Tortoise(config) + + def load_checkpoint( + self, + config, + checkpoint_dir, + ar_checkpoint_path=None, + diff_checkpoint_path=None, + clvp_checkpoint_path=None, + vocoder_checkpoint_path=None, + eval=False, + strict=True, + **kwargs, + ): # pylint: disable=unused-argument, redefined-builtin + """Load a model checkpoints from a directory. This model is with multiple checkpoint files and it + expects to have all the files to be under the given `checkpoint_dir` with the rigth names. + If eval is True, set the model to eval mode. + + Args: + config (TortoiseConfig): The model config. + checkpoint_dir (str): The directory where the checkpoints are stored. + ar_checkpoint_path (str, optional): The path to the autoregressive checkpoint. Defaults to None. + diff_checkpoint_path (str, optional): The path to the diffusion checkpoint. Defaults to None. + clvp_checkpoint_path (str, optional): The path to the CLVP checkpoint. Defaults to None. + vocoder_checkpoint_path (str, optional): The path to the vocoder checkpoint. Defaults to None. + eval (bool, optional): Whether to set the model to eval mode. Defaults to False. + strict (bool, optional): Whether to load the model strictly. Defaults to True. 
+ """ + if self.models_dir is None: + self.models_dir = checkpoint_dir + ar_path = ar_checkpoint_path or os.path.join(checkpoint_dir, "autoregressive.pth") + diff_path = diff_checkpoint_path or os.path.join(checkpoint_dir, "diffusion_decoder.pth") + clvp_path = clvp_checkpoint_path or os.path.join(checkpoint_dir, "clvp2.pth") + vocoder_checkpoint_path = vocoder_checkpoint_path or os.path.join(checkpoint_dir, "vocoder.pth") + + if os.path.exists(ar_path): + self.autoregressive.load_state_dict(torch.load(ar_path), strict=strict) + + if os.path.exists(diff_path): + self.diffusion.load_state_dict(torch.load(diff_path), strict=strict) + + if os.path.exists(clvp_path): + self.clvp.load_state_dict(torch.load(clvp_path), strict=strict) + + if os.path.exists(vocoder_checkpoint_path): + self.vocoder.load_state_dict( + config.model_args.vocoder.value.optionally_index( + torch.load( + vocoder_checkpoint_path, + map_location=torch.device("cpu"), + ) + ) + ) + + if eval: + self.autoregressive.post_init_gpt2_config(self.args.kv_cache) + self.autoregressive.eval() + self.diffusion.eval() + self.clvp.eval() + self.vocoder.eval() + + def train_step(self): + raise NotImplementedError("Tortoise Training is not implemented") diff --git a/TTS/tts/utils/assets/tortoise/tokenizer.json b/TTS/tts/utils/assets/tortoise/tokenizer.json new file mode 100644 index 0000000000..a128f27305 --- /dev/null +++ b/TTS/tts/utils/assets/tortoise/tokenizer.json @@ -0,0 +1 @@ 
+{"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"ant":159,"our
":160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c h","w h","th at","y ou","l i","v e","a c","t i","l d","m e","w as","g h","i d","l l","w i","en t","f or","a y","r o","v er","i c","h er","k e","h is","n o","u t","u n","i r","l o","w e","r i","h a","wi th","gh t","ou t","i m","i on","al l","a b","on e","n e","g e","ou ld","t er","m o","h ad","c e","s he","g o","s h","u r","a m","s o","p e","m y","d e","a re","b ut","om e","f r","the r","f e","s u","d o","c on","t e","a in","er e","p o","i f","the y","u s","a g","t r","n ow","ou n","th is","ha ve","no t","s a","i l","u p","th ing","fr om","a p","h im","ac k","at ion","an t","ou r","o p","li ke","u st","es s","b o","o k","u l","in d","e x","c om","s ome","the re","er s","c o","re s","m an","ar d","p l","w or","w ay","ti on","f o","c a","w ere","b y","at e","p ro","t ed","oun d","ow n","w ould","t s","wh at","q u","al ly","i ght","c 
k","g r","wh en","v en","c an","ou gh","in e","en d","p er","ou s","o d","id e","k now","t y","ver y","s i","a k","wh o","ab out","i ll","the m","es t","re d","y e","c ould","on g","you r","the ir","e m","j ust","o ther","in to","an y","wh i","u m","t w","as t","d er","d id","i e","be en","ac e","in k","it y","b ack","t ing","b r","mo re","a ke","p p","the n","s p","e l","u se","b l","sa id","o ver","ge t"]}} \ No newline at end of file diff --git a/TTS/utils/audio/torch_transforms.py b/TTS/utils/audio/torch_transforms.py index d7eed7052c..fd40ebb048 100644 --- a/TTS/utils/audio/torch_transforms.py +++ b/TTS/utils/audio/torch_transforms.py @@ -78,6 +78,7 @@ def __init__( power=None, use_htk=False, mel_norm="slaney", + normalized=False, ): super().__init__() self.n_fft = n_fft @@ -96,6 +97,7 @@ def __init__( self.mel_norm = mel_norm self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False) self.mel_basis = None + self.normalized = normalized if use_mel: self._build_mel_basis() @@ -125,7 +127,7 @@ def __call__(self, x): self.window, center=True, pad_mode="reflect", # compatible with audio.py - normalized=False, + normalized=self.normalized, onesided=True, return_complex=False, ) diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 8bf13bccd9..0d0b90648e 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -272,10 +272,16 @@ def download_model(self, model_name): os.makedirs(output_path, exist_ok=True) print(f" > Downloading model to {output_path}") # download from github release - self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar) - self.print_model_license(model_item=model_item) + if isinstance(model_item["github_rls_url"], list): + self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar) + else: + self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar) + self.print_model_license(model_item=model_item) # find downloaded files - 
output_model_path, output_config_path = self._find_files(output_path) + output_model_path = output_path + output_config_path = None + if model != "tortoise-v2": + output_model_path, output_config_path = self._find_files(output_path) # update paths in the config.json self._update_paths(output_path, output_config_path) return output_model_path, output_config_path, model_item @@ -415,6 +421,25 @@ def _download_zip_file(file_url, output_folder, progress_bar): # remove the extracted folder rmtree(os.path.join(output_folder, z.namelist()[0])) + @staticmethod + def _download_model_files(file_urls, output_folder, progress_bar): + """Download the github releases""" + for file_url in file_urls: + # download the file + r = requests.get(file_url, stream=True) + # extract the file + bease_filename = file_url.split("/")[-1] + temp_zip_name = os.path.join(output_folder, bease_filename) + total_size_in_bytes = int(r.headers.get("content-length", 0)) + block_size = 1024 # 1 Kibibyte + with open(temp_zip_name, "wb") as file: + if progress_bar: + progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + for data in r.iter_content(block_size): + if progress_bar: + progress_bar.update(len(data)) + file.write(data) + @staticmethod def _check_dict_key(my_dict, key): if key in my_dict.keys() and my_dict[key] is not None: diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 74e5c9ecf5..1b91521b13 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -1,3 +1,4 @@ +import os import time from typing import List @@ -31,6 +32,8 @@ def __init__( encoder_config: str = "", vc_checkpoint: str = "", vc_config: str = "", + model_dir: str = "", + voice_dir: str = None, use_cuda: bool = False, ) -> None: """General 🐸 TTS interface for inference. 
It takes a tts and a vocoder @@ -78,7 +81,7 @@ def __init__( self.d_vector_dim = 0 self.seg = self._get_segmenter("en") self.use_cuda = use_cuda - + self.voice_dir = voice_dir if self.use_cuda: assert torch.cuda.is_available(), "CUDA is not availabe on this machine." @@ -94,6 +97,10 @@ def __init__( self._load_vc(vc_checkpoint, vc_config, use_cuda) self.output_sample_rate = self.vc_config.audio["output_sample_rate"] + if model_dir: + self._load_tts_from_dir(model_dir, use_cuda) + self.output_sample_rate = self.tts_config.audio["output_sample_rate"] + @staticmethod def _get_segmenter(lang: str): """get the sentence segmenter for the given language. @@ -126,6 +133,19 @@ def _load_vc(self, vc_checkpoint: str, vc_config_path: str, use_cuda: bool) -> N if use_cuda: self.vc_model.cuda() + def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None: + """Load the TTS model from a directory. + + We assume the model knows how to load itself from the directory and there is a config.json file in the directory. + """ + + config = load_config(os.path.join(model_dir, "config.json")) + self.tts_config = config + self.tts_model = setup_tts_model(config) + self.tts_model.load_checkpoint(config, checkpoint_dir=model_dir, eval=True) + if use_cuda: + self.tts_model.cuda() + def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None: """Load the TTS model. @@ -220,6 +240,7 @@ def tts( style_text=None, reference_wav=None, reference_speaker_name=None, + **kwargs, ) -> List[int]: """🐸 TTS magic. Run all the models and generate speech. 
@@ -249,6 +270,9 @@ def tts( print(sens) # handle multi-speaker + if "voice_dir" in kwargs: + self.voice_dir = kwargs["voice_dir"] + kwargs.pop("voice_dir") speaker_embedding = None speaker_id = None if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"): @@ -275,7 +299,7 @@ def tts( else: speaker_embedding = None else: - if speaker_name: + if speaker_name and self.voice_dir is None: raise ValueError( f" [!] Missing speakers.json file path for selecting speaker {speaker_name}." "Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. " @@ -312,29 +336,39 @@ def tts( ) # compute a new d_vector from the given clip. - if speaker_wav is not None: + if speaker_wav is not None and self.tts_model.speaker_manager is not None: speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav) use_gl = self.vocoder_model is None if not reference_wav: for sen in sens: - # synthesize voice - outputs = synthesis( - model=self.tts_model, - text=sen, - CONFIG=self.tts_config, - use_cuda=self.use_cuda, - speaker_id=speaker_id, - style_wav=style_wav, - style_text=style_text, - use_griffin_lim=use_gl, - d_vector=speaker_embedding, - language_id=language_id, - ) + if hasattr(self.tts_model, "synthesize"): + sp_name = "random" if speaker_name is None else speaker_name + outputs = self.tts_model.synthesize( + text=sen, + config=self.tts_config, + speaker_id=sp_name, + extra_voice_dirs=self.voice_dir, + **kwargs, + ) + else: + # synthesize voice + outputs = synthesis( + model=self.tts_model, + text=sen, + CONFIG=self.tts_config, + use_cuda=self.use_cuda, + speaker_id=speaker_id, + style_wav=style_wav, + style_text=style_text, + use_griffin_lim=use_gl, + d_vector=speaker_embedding, + language_id=language_id, + ) waveform = outputs["wav"] - mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy() if not use_gl: + mel_postnet_spec = 
outputs["outputs"]["model_outputs"][0].detach().cpu().numpy() # denormalize tts output based on tts audio config mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T device_type = "cuda" if self.use_cuda else "cpu" diff --git a/docs/source/index.md b/docs/source/index.md index ccc7f66d41..9c9f922f17 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -51,6 +51,7 @@ models/forward_tts.md models/tacotron1-2.md models/overflow.md + models/tortoise.md .. toctree:: :maxdepth: 2 diff --git a/docs/source/models/tortoise.md b/docs/source/models/tortoise.md new file mode 100644 index 0000000000..c49a0fcb39 --- /dev/null +++ b/docs/source/models/tortoise.md @@ -0,0 +1,94 @@ +# Tortoise 🐢 +Tortoise is a very expressive TTS system with impressive voice cloning capabilities. It is based on a GPT-like autoregressive acoustic model that converts input +text to discretized acoustic tokens, a diffusion model that converts these tokens to mel-spectrogram frames and a Univnet vocoder to convert the spectrograms to +the final audio signal. The important downside is that Tortoise is very slow compared to the parallel TTS models like VITS. + +Big thanks to 👑[@manmay-nakhashi](https://github.com/manmay-nakhashi) who helped us implement Tortoise in 🐸TTS.
+ +Example use: + +```python +from TTS.tts.configs.tortoise_config import TortoiseConfig +from TTS.tts.models.tortoise import Tortoise + +config = TortoiseConfig() +model = Tortoise.init_from_config(config) +model.load_checkpoint(config, checkpoint_dir="paths/to/models_dir/", eval=True) + +# with random speaker +output_dict = model.synthesize(text, config, speaker_id="random", extra_voice_dirs=None, **kwargs) + +# cloning a speaker +output_dict = model.synthesize(text, config, speaker_id="speaker_n", extra_voice_dirs="path/to/speaker_n/", **kwargs) +``` + +Using 🐸TTS API: + +```python +from TTS.api import TTS +tts = TTS("tts_models/en/multi-dataset/tortoise-v2") + +# cloning `lj` voice from `TTS/tts/utils/assets/tortoise/voices/lj` +# with custom inference settings overriding defaults. +tts.tts_to_file(text="Hello, my name is Manmay , how are you?", + file_path="output.wav", + voice_dir="TTS/tts/utils/assets/tortoise/voices/", + speaker="lj", + num_autoregressive_samples=1, + diffusion_iterations=10) + +# Using presets with the same voice +tts.tts_to_file(text="Hello, my name is Manmay , how are you?", + file_path="output.wav", + voice_dir="TTS/tts/utils/assets/tortoise/voices/", + speaker="lj", + preset="ultra_fast") + +# Random voice generation +tts.tts_to_file(text="Hello, my name is Manmay , how are you?", + file_path="output.wav") +``` + +Using 🐸TTS Command line: + +```console +# cloning the `lj` voice +tts --model_name tts_models/en/multi-dataset/tortoise-v2 \ +--text "This is an example." \ +--out_path "/data/speech_synth/coqui-tts/TTS/tests/outputs/output.wav" \ +--voice_dir TTS/tts/utils/assets/tortoise/voices/ \ +--speaker_idx "lj" \ +--progress_bar True + +# Random voice generation +tts --model_name tts_models/en/multi-dataset/tortoise-v2 \ +--text "This is an example."
\ +--out_path "/data/speech_synth/coqui-tts/TTS/tests/outputs/output.wav" \ +--progress_bar True +``` + + +## Important resources & papers +- Original Repo: https://github.com/neonbjb/tortoise-tts +- Faster implementation: https://github.com/152334H/tortoise-tts-fast +- Univnet: https://arxiv.org/abs/2106.07889 +- Latent Diffusion:https://arxiv.org/abs/2112.10752 +- DALL-E: https://arxiv.org/abs/2102.12092 + +## TortoiseConfig +```{eval-rst} +.. autoclass:: TTS.tts.configs.tortoise_config.TortoiseConfig + :members: +``` + +## TortoiseArgs +```{eval-rst} +.. autoclass:: TTS.tts.models.tortoise.TortoiseArgs + :members: +``` + +## Tortoise Model +```{eval-rst} +.. autoclass:: TTS.tts.models.tortoise.Tortoise + :members: +``` diff --git a/notebooks/Tortoise.ipynb b/notebooks/Tortoise.ipynb new file mode 100644 index 0000000000..788d99e0cf --- /dev/null +++ b/notebooks/Tortoise.ipynb @@ -0,0 +1,108 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "4d50310e-f094-42e0-af30-1e42b13ceb95", + "metadata": {}, + "outputs": [], + "source": [ + "#@title # Setup\n", + "# Imports used through the rest of the notebook.\n", + "import torch\n", + "import torchaudio\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import IPython\n", + "\n", + "from TTS.tts.models.tortoise import TextToSpeech\n", + "from TTS.tts.layers.tortoise.audio_utils import load_audio, load_voice, load_voices\n", + "\n", + "# This will download all the models used by Tortoise from the HuggingFace hub.\n", + "tts = TextToSpeech()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e126c3c3-d90a-492f-b5bb-0d86587f15cc", + "metadata": {}, + "outputs": [], + "source": [ + "# This is the text that will be spoken.\n", + "text = \"Joining two modalities results in a surprising increase in generalization! 
What would happen if we combined them all?\" #@param {type:\"string\"}\n", + "#@markdown Show code for multiline text input\n", + "# Here's something for the poetically inclined.. (set text=)\n", + "\"\"\"\n", + "Then took the other, as just as fair,\n", + "And having perhaps the better claim,\n", + "Because it was grassy and wanted wear;\n", + "Though as for that the passing there\n", + "Had worn them really about the same,\"\"\"\n", + "\n", + "# Pick a \"preset mode\" to determine quality. Options: {\"ultra_fast\", \"fast\" (default), \"standard\", \"high_quality\"}. See docs in api.py\n", + "preset = \"fast\" #@param [\"ultra_fast\", \"fast\", \"standard\", \"high_quality\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9413f553-5bd0-4820-bad4-edd7fd7d2370", + "metadata": {}, + "outputs": [], + "source": [ + "%ls ../TTS/tts/utils/assets/tortoise/voices/\n", + "import IPython\n", + "IPython.display.Audio(filename='../TTS/tts/utils/assets/tortoise/voices/lj/1.wav')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96a98ae5-313b-40d1-9311-5a785f2c9a4e", + "metadata": {}, + "outputs": [], + "source": [ + "#@markdown Pick one of the voices from the output above\n", + "voice = 'lj' #@param {type:\"string\"}\n", + "\n", + "#@markdown Load it and send it through Tortoise.\n", + "voice_samples, conditioning_latents = load_voice(voice)\n", + "gen = tts.tts_with_preset(text, voice_samples=voice_samples, conditioning_latents=conditioning_latents, \n", + " preset=preset)\n", + "torchaudio.save('generated.wav', gen.squeeze(0).cpu(), 24000)\n", + "IPython.display.Audio('generated.wav')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "04e473e5-c489-4a78-aa11-03e89a778ed8", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + 
"name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/requirements.txt b/requirements.txt index ee4f2677fc..57640b6f02 100644 --- a/requirements.txt +++ b/requirements.txt @@ -45,3 +45,8 @@ g2pkk>=0.1.1 bangla==0.0.2 bnnumerizer bnunicodenormalizer==0.1.1 + +#deps for tortoise +k_diffusion +einops +transformers \ No newline at end of file diff --git a/tests/inference_tests/test_python_api.py b/tests/inference_tests/test_python_api.py index 91648c796e..f8ee4505d4 100644 --- a/tests/inference_tests/test_python_api.py +++ b/tests/inference_tests/test_python_api.py @@ -12,6 +12,7 @@ if is_coqui_available: + class CS_APITest(unittest.TestCase): def test_speakers(self): tts = CS_API() @@ -40,7 +41,6 @@ def test_tts(self): self.assertEqual(sr, 44100) self.assertGreater(len(wav), 1) - class TTSTest(unittest.TestCase): def test_single_speaker_model(self): tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False, gpu=False) @@ -86,7 +86,9 @@ def test_studio_model(self): def test_multi_speaker_multi_lingual_model(self): tts = TTS() tts.load_tts_model_by_name(tts.models[0]) # YourTTS - tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path=OUTPUT_PATH) + tts.tts_to_file( + text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path=OUTPUT_PATH + ) self.assertTrue(tts.is_multi_speaker) self.assertTrue(tts.is_multi_lingual)