feat(gpt): switch to safetensor model
fumiama committed Aug 24, 2024
1 parent d6e1584 commit 8a503fd
Showing 3 changed files with 32 additions and 99 deletions.
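In short: the GPT weights move from a single pickled checkpoint (asset/GPT.pt, loaded with torch.load) to a Hugging Face-style safetensors folder (asset/gpt), with the text/audio embeddings and heads consolidated in the already-safetensors Embed module. As a hedged aside (not part of the commit), this is the usual benefit of the format, shown with the standard safetensors API:

from safetensors.torch import load_file

# Unlike unpickling a .pt file, load_file never executes embedded code,
# and tensors can be memory-mapped rather than copied up front.
state_dict = load_file("asset/Embed.safetensors")
print(sorted(state_dict)[:3])  # tensor names stored in the checkpoint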
2 changes: 1 addition & 1 deletion ChatTTS/config/config.py
@@ -5,7 +5,7 @@
 class Path:
     vocos_ckpt_path: str = "asset/Vocos.pt"
     dvae_ckpt_path: str = "asset/DVAE_full.pt"
-    gpt_ckpt_path: str = "asset/GPT.pt"
+    gpt_ckpt_path: str = "asset/gpt"
     decoder_ckpt_path: str = "asset/Decoder.pt"
     tokenizer_path: str = "asset/tokenizer"
     embed_path: str = "asset/Embed.safetensors"
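This one-line config change is the pivot of the commit: gpt_ckpt_path now names a directory instead of a .pt file. A minimal sketch of how such a folder is consumed, assuming the conventional Hugging Face layout (the exact folder contents are an assumption; the diff only shows the path):

from transformers import LlamaModel

# asset/gpt/ is expected to contain config.json plus safetensors weights;
# from_pretrained resolves and loads them directly from the directory,
# as gpt.py below now does.
gpt = LlamaModel.from_pretrained("asset/gpt")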
51 changes: 21 additions & 30 deletions ChatTTS/core.py
@@ -139,10 +139,6 @@ def load(
             use_flash_attn=use_flash_attn,
             use_vllm=use_vllm,
             experimental=experimental,
-            **{
-                k: os.path.join(download_path, v)
-                for k, v in asdict(self.config.path).items()
-            },
         )

     def unload(self):
@@ -225,12 +221,6 @@ def interrupt(self):
    @torch.no_grad()
    def _load(
        self,
-        vocos_ckpt_path: str = None,
-        dvae_ckpt_path: str = None,
-        gpt_ckpt_path: str = None,
-        embed_path: str = None,
-        decoder_ckpt_path: str = None,
-        tokenizer_path: str = None,
        device: Optional[torch.device] = None,
        compile: bool = False,
        coef: Optional[str] = None,
@@ -260,8 +250,8 @@ def _load(
            )
            .eval()
        )
-        assert vocos_ckpt_path, "vocos_ckpt_path should not be None"
-        vocos.load_state_dict(torch.load(vocos_ckpt_path, weights_only=True, mmap=True))
+        assert self.config.path.vocos_ckpt_path, "vocos_ckpt_path should not be None"
+        vocos.load_state_dict(torch.load(self.config.path.vocos_ckpt_path, weights_only=True, mmap=True))
        self.vocos = vocos
        self.logger.log(logging.INFO, "vocos loaded.")

@@ -277,35 +267,36 @@
            .eval()
        )
        coef = str(dvae)
-        assert dvae_ckpt_path, "dvae_ckpt_path should not be None"
-        dvae.load_state_dict(torch.load(dvae_ckpt_path, weights_only=True, mmap=True))
+        assert self.config.path.dvae_ckpt_path, "dvae_ckpt_path should not be None"
+        dvae.load_state_dict(torch.load(self.config.path.dvae_ckpt_path, weights_only=True, mmap=True))
        self.dvae = dvae
        self.logger.log(logging.INFO, "dvae loaded.")

+        embed = Embed(
+            self.config.embed.hidden_size,
+            self.config.embed.num_audio_tokens,
+            self.config.embed.num_text_tokens,
+            self.config.embed.num_vq,
+        )
+        embed.from_pretrained(self.config.path.embed_path)
+        self.embed = embed
+        self.logger.log(logging.INFO, "embed loaded.")
+
        gpt = GPT(
            gpt_config=asdict(self.config.gpt),
+            embed=self.embed,
            use_flash_attn=use_flash_attn,
            use_vllm=use_vllm,
            device=device,
            device_gpt=self.device_gpt,
            logger=self.logger,
        ).eval()
-        assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
-        gpt.from_pretrained(gpt_ckpt_path, experimental=experimental)
+        assert self.config.path.gpt_ckpt_path, "gpt_ckpt_path should not be None"
+        gpt.from_pretrained(self.config.path.gpt_ckpt_path, self.config.path.embed_path, experimental=experimental)
        gpt.prepare(compile=compile and "cuda" in str(device))
        self.gpt = gpt
        self.logger.log(logging.INFO, "gpt loaded.")

-        embed = Embed(
-            self.config.embed.hidden_size,
-            self.config.embed.num_audio_tokens,
-            self.config.embed.num_text_tokens,
-            self.config.embed.num_vq,
-        )
-        embed.from_pretrained(self.config.path.embed_path)
-        self.embed = embed
-        self.logger.log(logging.INFO, "embed loaded.")
-
        self.speaker = Speaker(
            self.config.gpt.hidden_size, self.config.spk_stat, device
        )
@@ -321,15 +312,15 @@
            .eval()
        )
        coef = str(decoder)
-        assert decoder_ckpt_path, "decoder_ckpt_path should not be None"
+        assert self.config.path.decoder_ckpt_path, "decoder_ckpt_path should not be None"
        decoder.load_state_dict(
-            torch.load(decoder_ckpt_path, weights_only=True, mmap=True)
+            torch.load(self.config.path.decoder_ckpt_path, weights_only=True, mmap=True)
        )
        self.decoder = decoder
        self.logger.log(logging.INFO, "decoder loaded.")

-        if tokenizer_path:
-            self.tokenizer = Tokenizer(tokenizer_path)
+        if self.config.path.tokenizer_path:
+            self.tokenizer = Tokenizer(self.config.path.tokenizer_path)
            self.logger.log(logging.INFO, "tokenizer loaded.")

        self.coef = coef
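The net effect in core.py: _load no longer threads six checkpoint paths through its signature, and the os.path.join forwarding in load() is gone with it; every location is read from self.config.path. The torch.load pattern repeated above distills to the following standalone sketch (the helper name is hypothetical):

import torch

def load_checkpoint(module: torch.nn.Module, ckpt_path: str) -> None:
    assert ckpt_path, "checkpoint path should not be None"
    # weights_only=True refuses to unpickle arbitrary objects;
    # mmap=True maps the file so tensors are paged in on first access.
    module.load_state_dict(torch.load(ckpt_path, weights_only=True, mmap=True))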
78 changes: 10 additions & 68 deletions ChatTTS/model/gpt.py
@@ -9,20 +9,21 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.utils.parametrize as P
-from torch.nn.utils.parametrizations import weight_norm
 from tqdm import tqdm
 from transformers import LlamaModel, LlamaConfig
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.utils import is_flash_attn_2_available

 from ..utils import del_all
+from .embed import Embed


 class GPT(nn.Module):
     def __init__(
         self,
         gpt_config: dict,
+        embed: Embed,
         use_flash_attn=False,
         use_vllm=False,
         device=torch.device("cpu"),
@@ -52,86 +53,27 @@ def __init__(

        self.gpt, self.llama_config = self._build_llama(gpt_config, self.device_gpt)

-        self.model_dim = int(self.gpt.config.hidden_size)
-        self.emb_code = nn.ModuleList(
-            [
-                nn.Embedding(
-                    self.num_audio_tokens,
-                    self.model_dim,
-                    device=self.device_gpt,
-                )
-                for _ in range(self.num_vq)
-            ],
-        )
-        self.emb_text = nn.Embedding(
-            self.num_text_tokens, self.model_dim, device=self.device_gpt
-        )
-
-        self.head_text = weight_norm(
-            nn.Linear(
-                self.model_dim,
-                self.num_text_tokens,
-                bias=False,
-                device=device,
-            ),
-            name="weight",
-        )
-        self.head_code = nn.ModuleList(
-            [
-                weight_norm(
-                    nn.Linear(
-                        self.model_dim,
-                        self.num_audio_tokens,
-                        bias=False,
-                        device=device,
-                    ),
-                    name="weight",
-                )
-                for _ in range(self.num_vq)
-            ],
-        )
+        self.emb_code = [ec.__call__ for ec in embed.emb_code]
+        self.emb_text = embed.emb_text.__call__
+        self.head_text = embed.head_text.__call__
+        self.head_code = [hc.__call__ for hc in embed.head_code]

-    def from_pretrained(self, file_path: str, experimental=False):
+    def from_pretrained(self, gpt_folder: str, embed_file_path: str, experimental=False):
        if self.is_vllm and platform.system().lower() == "linux":
            from safetensors.torch import save_file

            from .velocity import LLM, PostModel

-            vllm_folder = Path(os.getcwd()) / "asset" / "vllm"
-            if not os.path.exists(vllm_folder):
-                self.logger.info("initializing vLLM model to %s", str(vllm_folder))
-                vllm_folder.mkdir(mode=0o755, parents=True, exist_ok=True)
-                gpt = GPT(gpt_config=self.config)
-                gpt.from_pretrained(file_path)
-                gpt.gpt.save_pretrained(vllm_folder / "gpt")
-                post_model = (
-                    PostModel(
-                        int(gpt.gpt.config.hidden_size),
-                        self.num_audio_tokens,
-                        self.num_text_tokens,
-                    )
-                    .to(self.device)
-                    .eval()
-                )
-                post_model.emb_code = gpt.emb_code
-                post_model.emb_text = gpt.emb_text
-                post_model.head_text = gpt.head_text
-                post_model.head_code = gpt.head_code
-                save_file(
-                    post_model.state_dict(),
-                    vllm_folder / "post_model.safetensors",
-                )
-                del post_model, gpt
            self.llm = LLM(
-                model=str(vllm_folder / "gpt"),
+                model=gpt_folder,
                num_audio_tokens=self.num_audio_tokens,
                num_text_tokens=self.num_text_tokens,
-                post_model_path=vllm_folder / "post_model.safetensors",
+                post_model_path=embed_file_path,
            )
            self.logger.info("vLLM model loaded")
            return

-        self.load_state_dict(torch.load(file_path, weights_only=True, mmap=True))
+        self.gpt: LlamaModel = LlamaModel.from_pretrained(gpt_folder)

        if (
            experimental
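In gpt.py the GPT module stops owning its embedding and head parameters: it keeps bound callables into the shared Embed module (constructed in core.py above), and the transformer weights now come from LlamaModel.from_pretrained on the safetensors folder. A self-contained sketch of the sharing pattern, with simplified stand-ins rather than the real ChatTTS classes:

import torch
import torch.nn as nn

class Embed(nn.Module):
    # simplified stand-in for ChatTTS/model/embed.py
    def __init__(self, dim: int, n_text: int):
        super().__init__()
        self.emb_text = nn.Embedding(n_text, dim)
        self.head_text = nn.Linear(dim, n_text, bias=False)

class TinyGPT(nn.Module):
    def __init__(self, embed: Embed):
        super().__init__()
        # Storing the bound __call__ (a plain callable, not an nn.Module)
        # sidesteps nn.Module's attribute registration, so the shared
        # parameters are not duplicated into this module's state_dict.
        self.emb_text = embed.emb_text.__call__
        self.head_text = embed.head_text.__call__

embed = Embed(dim=8, n_text=16)
gpt = TinyGPT(embed)
logits = gpt.head_text(gpt.emb_text(torch.tensor([1, 2, 3])))
assert len(list(gpt.parameters())) == 0  # all weights live in Embed

One consequence: saving TinyGPT's state_dict captures nothing, which matches the commit's direction of keeping these weights in a single Embed.safetensors file.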
