diff --git a/ChatTTS/model/velocity/llm.py b/ChatTTS/model/velocity/llm.py index b473b562c..a37f5cb34 100644 --- a/ChatTTS/model/velocity/llm.py +++ b/ChatTTS/model/velocity/llm.py @@ -2,12 +2,12 @@ from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from vllm.utils import Counter from .configs import EngineArgs from .llm_engine import LLMEngine from .output import RequestOutput from .sampling_params import SamplingParams -from vllm.utils import Counter class LLM: diff --git a/ChatTTS/model/velocity/worker.py b/ChatTTS/model/velocity/worker.py index 90aca7f32..294c77d37 100644 --- a/ChatTTS/model/velocity/worker.py +++ b/ChatTTS/model/velocity/worker.py @@ -12,6 +12,7 @@ from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine + from .model_runner import ModelRunner diff --git a/examples/cmd/run.py b/examples/cmd/run.py index 9249eb5b9..389890e80 100644 --- a/examples/cmd/run.py +++ b/examples/cmd/run.py @@ -1,22 +1,23 @@ +import os, sys + +if sys.platform == "darwin": + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + +now_dir = os.getcwd() +sys.path.append(now_dir) + from typing import Optional, List import argparse -import os -import sys import numpy as np import ChatTTS + from tools.logger import get_logger from tools.audio import pcm_arr_to_mp3_view from tools.normalizer.en import normalizer_en_nemo_text from tools.normalizer.zh import normalizer_zh_tn -if sys.platform == "darwin": - os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" - -now_dir = os.getcwd() -sys.path.append(now_dir) - logger = get_logger("Command") diff --git a/examples/web/funcs.py b/examples/web/funcs.py index 933fb6833..b8fcd4f4a 100644 --- a/examples/web/funcs.py +++ b/examples/web/funcs.py @@ -1,4 +1,3 @@ -import sys import random from typing import Optional from time import sleep