Skip to content

Commit

Permalink
feat(core): dvae & vocos switch to safetensors
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama committed Oct 15, 2024
1 parent b3d511b commit b9b007e
Show file tree
Hide file tree
Showing 10 changed files with 63 additions and 56 deletions.
6 changes: 3 additions & 3 deletions ChatTTS/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

@dataclass(repr=False, eq=False)
class Path:
vocos_ckpt_path: str = "asset/Vocos.pt"
dvae_ckpt_path: str = "asset/DVAE_full.pt"
vocos_ckpt_path: str = "asset/Vocos.safetensors"
dvae_ckpt_path: str = "asset/DVAE.safetensors"
gpt_ckpt_path: str = "asset/gpt"
decoder_ckpt_path: str = "asset/Decoder.pt"
decoder_ckpt_path: str = "asset/Decoder.safetensors"
tokenizer_path: str = "asset/tokenizer"
embed_path: str = "asset/Embed.safetensors"

Expand Down
51 changes: 21 additions & 30 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .config import Config
from .model import DVAE, Embed, GPT, gen_logits, Tokenizer, Speaker
from .utils import (
load_safetensors,
check_all_assets,
download_all_assets,
select_device,
Expand Down Expand Up @@ -97,7 +98,7 @@ def download_models(
try:
download_path = snapshot_download(
repo_id="2Noise/ChatTTS",
allow_patterns=["*.pt", "*.yaml", "*.json", "*.safetensors"],
allow_patterns=["*.yaml", "*.json", "*.safetensors"],
)
except:
download_path = None
Expand Down Expand Up @@ -263,26 +264,22 @@ def _load(
.eval()
)
assert vocos_ckpt_path, "vocos_ckpt_path should not be None"
vocos.load_state_dict(torch.load(vocos_ckpt_path, weights_only=True, mmap=True))
vocos.load_state_dict(load_safetensors(vocos_ckpt_path))
self.vocos = vocos
self.logger.log(logging.INFO, "vocos loaded.")

dvae = (
DVAE(
decoder_config=asdict(self.config.dvae.decoder),
encoder_config=asdict(self.config.dvae.encoder),
vq_config=asdict(self.config.dvae.vq),
dim=self.config.dvae.decoder.idim,
coef=coef,
device=device,
)
.to(device)
.eval()
dvae = DVAE(
decoder_config=asdict(self.config.dvae.decoder),
encoder_config=asdict(self.config.dvae.encoder),
vq_config=asdict(self.config.dvae.vq),
dim=self.config.dvae.decoder.idim,
coef=coef,
device=device,
)
coef = str(dvae)
assert dvae_ckpt_path, "dvae_ckpt_path should not be None"
dvae.load_state_dict(torch.load(dvae_ckpt_path, weights_only=True, mmap=True))
self.dvae = dvae
dvae.load_pretrained(dvae_ckpt_path, device)
self.dvae = dvae.eval()
self.logger.log(logging.INFO, "dvae loaded.")

embed = Embed(
Expand All @@ -291,7 +288,7 @@ def _load(
self.config.embed.num_text_tokens,
self.config.embed.num_vq,
)
embed.from_pretrained(embed_path, device=device)
embed.load_pretrained(embed_path, device=device)
self.embed = embed.to(device)
self.logger.log(logging.INFO, "embed loaded.")

Expand All @@ -305,7 +302,7 @@ def _load(
logger=self.logger,
).eval()
assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
gpt.from_pretrained(gpt_ckpt_path, embed_path, experimental=experimental)
gpt.load_pretrained(gpt_ckpt_path, embed_path, experimental=experimental)
gpt.prepare(compile=compile and "cuda" in str(device))
self.gpt = gpt
self.logger.log(logging.INFO, "gpt loaded.")
Expand All @@ -315,22 +312,16 @@ def _load(
)
self.logger.log(logging.INFO, "speaker loaded.")

decoder = (
DVAE(
decoder_config=asdict(self.config.decoder),
dim=self.config.decoder.idim,
coef=coef,
device=device,
)
.to(device)
.eval()
decoder = DVAE(
decoder_config=asdict(self.config.decoder),
dim=self.config.decoder.idim,
coef=coef,
device=device,
)
coef = str(decoder)
assert decoder_ckpt_path, "decoder_ckpt_path should not be None"
decoder.load_state_dict(
torch.load(decoder_ckpt_path, weights_only=True, mmap=True)
)
self.decoder = decoder
decoder.load_pretrained(decoder_ckpt_path, device)
self.decoder = decoder.eval()
self.logger.log(logging.INFO, "decoder loaded.")

if tokenizer_path:
Expand Down
9 changes: 8 additions & 1 deletion ChatTTS/model/dvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import pybase16384 as b14
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from vector_quantize_pytorch import GroupedResidualFSQ

from ..utils import load_safetensors


class ConvNeXtBlock(nn.Module):
def __init__(
Expand Down Expand Up @@ -250,6 +251,12 @@ def __call__(
self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode"
) -> torch.Tensor:
return super().__call__(inp, mode)

@torch.inference_mode()
def load_pretrained(self, filename: str, device: torch.device):
    """Restore this module's weights from a safetensors checkpoint.

    Args:
        filename: Path to the ``.safetensors`` checkpoint file.
        device: Target device the module is moved to after loading.
    """
    weights = load_safetensors(filename)
    self.load_state_dict(weights)
    self.to(device)

@torch.inference_mode()
def forward(
Expand Down
10 changes: 4 additions & 6 deletions ChatTTS/model/embed.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from safetensors.torch import safe_open
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm

from ..utils import load_safetensors


class Embed(nn.Module):
def __init__(
Expand Down Expand Up @@ -34,11 +35,8 @@ def __init__(
)

@torch.inference_mode()
def from_pretrained(self, filename: str, device: torch.device):
state_dict_tensors = {}
with safe_open(filename, framework="pt") as f:
for k in f.keys():
state_dict_tensors[k] = f.get_tensor(k)
def load_pretrained(self, filename: str, device: torch.device):
    """Fill this embedding module from a safetensors checkpoint, then move it to *device*.

    Args:
        filename: Path to the ``.safetensors`` checkpoint file.
        device: Device to place the module on once weights are loaded.
    """
    self.load_state_dict(load_safetensors(filename))
    self.to(device)

Expand Down
2 changes: 1 addition & 1 deletion ChatTTS/model/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
self.head_text = embed.head_text.__call__
self.head_code = [hc.__call__ for hc in embed.head_code]

def from_pretrained(
def load_pretrained(
self, gpt_folder: str, embed_file_path: str, experimental=False
):
if self.is_vllm and platform.system().lower() == "linux":
Expand Down
8 changes: 4 additions & 4 deletions ChatTTS/res/sha256_map.json
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{
"sha256_asset_Decoder_pt" : "9964e36e840f0e3a748c5f716fe6de6490d2135a5f5155f4a642d51860e2ec38",
"sha256_asset_DVAE_full_pt" : "553eb75763511e23f3e5f86303e2163c5ca775489d637fb635d979c8ae58bbe5",
"sha256_asset_Embed_safetensors" : "2ff0be7134934155741b643b74e32fb6bf3eec41257984459b2ed60cdb4c48b0",
"sha256_asset_Vocos_pt" : "09a670eda1c08b740013679c7a90ebb7f1a97646ea7673069a6838e6b51d6c58",
"sha256_asset_Decoder_safetensors": "77aa55e0a977949c4733df3c6f876fa85860d3298cba63295a7bc6901729d4e0",
"sha256_asset_DVAE_safetensors" : "1d0b044a8368c0513100a2eca98456b289e6be6a18b7a63be1bcaa315ea874d9",
"sha256_asset_Embed_safetensors" : "2ff0be7134934155741b643b74e32fb6bf3eec41257984459b2ed60cdb4c48b0",
"sha256_asset_Vocos_safetensors" : "07e5561491cce41f7f90cfdb94b2ff263ff5742c3d89339db99b17ad82cc3f44",

"sha256_asset_gpt_config_json" : "0aaa1ecd96c49ad4f473459eb1982fa7ad79fa5de08cde2781bf6ad1f9a0c236",
"sha256_asset_gpt_model_safetensors" : "cd0806fd971f52f6a22c923ec64982b305e817bcc41ca83417fcf9141b984a0f",
Expand Down
2 changes: 1 addition & 1 deletion ChatTTS/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .dl import check_all_assets, download_all_assets
from .gpu import select_device
from .io import get_latest_modified_file, del_all
from .io import load_safetensors, get_latest_modified_file, del_all
from .log import logger
6 changes: 3 additions & 3 deletions ChatTTS/utils/dl.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ def check_all_assets(base_dir: Path, sha256_map: Dict[str, str], update=False) -
base_dir,
"asset",
names=(
"Decoder.pt",
"DVAE_full.pt",
"Decoder.safetensors",
"DVAE.safetensors",
"Embed.safetensors",
"Vocos.pt",
"Vocos.safetensors",
),
sha256_map=sha256_map,
update=update,
Expand Down
11 changes: 11 additions & 0 deletions ChatTTS/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,20 @@
from typing import Union
from dataclasses import is_dataclass

from safetensors import safe_open
import torch

from .log import logger


@torch.inference_mode()
def load_safetensors(filename: str) -> dict:
    """Load every tensor in a ``.safetensors`` file into a plain dict.

    Args:
        filename: Path to a safetensors checkpoint file.

    Returns:
        A mapping of tensor name -> ``torch.Tensor``, eagerly materialized,
        suitable for passing straight to ``nn.Module.load_state_dict``.
    """
    # safe_open is lazy; pull each tensor out explicitly so the returned
    # dict owns fully-loaded tensors independent of the open file handle.
    with safe_open(filename, framework="pt") as f:
        return {k: f.get_tensor(k) for k in f.keys()}

def get_latest_modified_file(directory):

files = [os.path.join(directory, f) for f in os.listdir(directory)]
Expand Down
14 changes: 7 additions & 7 deletions tools/checksum/tmpl.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package main

var files = [...]string{
"asset/Decoder.pt",
"asset/DVAE_full.pt",
"asset/Decoder.safetensors",
"asset/DVAE.safetensors",
"asset/Embed.safetensors",
"asset/Vocos.pt",
"asset/Vocos.safetensors",

"asset/gpt/config.json",
"asset/gpt/model.safetensors",
Expand All @@ -15,10 +15,10 @@ var files = [...]string{
}

const jsontmpl = `{
"sha256_asset_Decoder_pt" : "%s",
"sha256_asset_DVAE_full_pt" : "%s",
"sha256_asset_Embed_safetensors" : "%s",
"sha256_asset_Vocos_pt" : "%s",
"sha256_asset_Decoder_safetensors": "%s",
"sha256_asset_DVAE_safetensors" : "%s",
"sha256_asset_Embed_safetensors" : "%s",
"sha256_asset_Vocos_safetensors" : "%s",
"sha256_asset_gpt_config_json" : "%s",
"sha256_asset_gpt_model_safetensors" : "%s",
Expand Down

0 comments on commit b9b007e

Please sign in to comment.