
Commit 73623b9
- chatchat init can now take the Xinference API endpoint, LLM model, and Embedding model directly, completing knowledge base initialization in one step (chatchat-space#4425)

- Fix several configuration option errors
liunux4odoo authored Jul 5, 2024
1 parent de834b4 commit 73623b9
Showing 9 changed files with 157 additions and 95 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -182,6 +182,10 @@ set CHATCHAT_ROOT=/path/to/chatchat_data
2. Run the initialization
```shell
chatchat init
# If an Xinference server is already running, you can pass its API endpoint, LLM model, and Embedding model here, skip steps 3 and 4, and start the services directly as in step 5:
# chatchat init -x http://127.0.0.1:9999/v1 -l qwen2-instruct -e bce -r
# chatchat start -a
# See chatchat init --help for details
```
This command performs the following actions:
- Create all required data directories
56 changes: 48 additions & 8 deletions libs/chatchat-server/chatchat/cli.py
@@ -4,7 +4,7 @@
import typing as t

from chatchat.startup import main as startup_main
from chatchat.init_database import main as kb_main, create_tables
from chatchat.init_database import main as kb_main, create_tables, folder2db
from chatchat.settings import Settings
from chatchat.utils import build_logger

@@ -18,20 +18,60 @@ def main():


@main.command("init", help="项目初始化")
# @click.option("-k", "--recreate-kb", "kb_names",
# show_default=True,
# default="samples",
# help="同时重建知识库。可以指定多个知识库名称,以 , 分隔。")
def init():
@click.option("-x", "--xinference-endpoint", "xf_endpoint",
help="指定Xinference API 服务地址。默认为 http://127.0.0.1:9997/v1")
@click.option("-l", "--llm-model",
help="指定默认 LLM 模型。默认为 glm4-chat")
@click.option("-e", "--embed-model",
help="指定默认 Embedding 模型。默认为 bge-large-zh-v1.5")
@click.option("-r", "--recreate-kb",
is_flag=True,
show_default=True,
default=False,
help="同时重建知识库(必须确保指定的 embed model 可用)。")
@click.option("-k", "--kb-names", "kb_names",
show_default=True,
default="samples",
help="要重建知识库的名称。可以指定多个知识库名称,以 , 分隔。")
def init(
xf_endpoint: str = "",
llm_model: str = "",
embed_model: str = "",
recreate_kb: bool = False,
kb_names: str = "",
):
Settings.set_auto_reload(False)
bs = Settings.basic_settings
kb_names = [x.strip() for x in kb_names.split(",")]
logger.info(f"开始初始化项目数据目录:{Settings.CHATCHAT_ROOT}")
Settings.basic_settings.make_dirs()
logger.info("创建所有数据目录:成功。")
shutil.copytree(bs.PACKAGE_ROOT / "data/knowledge_base/samples", Path(bs.KB_ROOT_PATH) / "samples", dirs_exist_ok=True)
logger.info("复制 samples 知识库:成功。")
logger.info("复制 samples 知识库文件:成功。")
create_tables()
logger.info("初始化知识库数据库:成功。")

if xf_endpoint:
Settings.model_settings.MODEL_PLATFORMS[0].api_base_url = xf_endpoint
if llm_model:
Settings.model_settings.DEFAULT_LLM_MODEL = llm_model
if embed_model:
Settings.model_settings.DEFAULT_EMBEDDING_MODEL = embed_model

Settings.createl_all_templates()
Settings.set_auto_reload(True)

logger.info("生成默认配置文件:成功。")
logger.warning("<red>请先修改 model_settings.yaml 配置正确的模型平台、LLM模型和Embed模型,然后执行 chatchat kb -r 初始化知识库。</red>")
logger.warning("<red>请先检查 model_settings.yaml 里模型平台、LLM模型和Embed模型信息正确</red>")

if recreate_kb:
folder2db(kb_names=kb_names,
mode="recreate_vs",
vs_type=Settings.kb_settings.DEFAULT_VS_TYPE,
embed_model=Settings.model_settings.DEFAULT_EMBEDDING_MODEL)
logger.success("所有初始化已完成,执行 chatchat start -a 启动服务。")
else:
logger.warning("执行 chatchat kb -r 初始化知识库,然后 chatchat start -a 启动服务。")


main.add_command(startup_main, "start")
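
As a quick sanity check of the new flags, the command can also be driven through click's test runner. A minimal sketch (endpoint and model names are the illustrative values from the README above; a running Xinference server is assumed for `-r` to succeed):

```python
# Hedged sketch: exercise the new `chatchat init` flags via click's CliRunner.
from click.testing import CliRunner

from chatchat.cli import main

runner = CliRunner()
result = runner.invoke(main, [
    "init",
    "-x", "http://127.0.0.1:9999/v1",  # Xinference API endpoint
    "-l", "qwen2-instruct",            # default LLM model
    "-e", "bce",                       # default Embedding model
    "-r",                              # rebuild knowledge bases right away
    "-k", "samples",                   # knowledge base names, comma-separated
])
print(result.output)                   # logs from the init steps above
```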
32 changes: 23 additions & 9 deletions libs/chatchat-server/chatchat/pydantic_settings_file.py
@@ -180,6 +180,18 @@ class BaseFileSettings(BaseSettings):
env_file_encoding="utf-8",
)

def model_post_init(self, __context: os.Any) -> None:
self._auto_reload = True
return super().model_post_init(__context)

@property
def auto_reload(self) -> bool:
return self._auto_reload

@auto_reload.setter
def auto_reload(self, val: bool):
self._auto_reload = val

    @classmethod
    def settings_customise_sources(
        cls,
@@ -191,23 +203,24 @@ def settings_customise_sources(
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        return init_settings, env_settings, dotenv_settings, YamlConfigSettingsSource(settings_cls)

    @classmethod
    def create_template_file(
        cls,
        init_kwds: t.Dict = {},
        self,
        model_obj: BaseFileSettings = None,
        dump_kwds: t.Dict = {},
        sub_comments: t.Dict[str, SubModelComment] = {},
        write_file: bool | str | Path = False,
        file_format: t.Literal["yaml", "json"] = "yaml",
    ) -> str:
        if model_obj is None:
            model_obj = self
        if file_format == "yaml":
            template = YamlTemplate(cls(**init_kwds), dump_kwds=dump_kwds, sub_comments=sub_comments)
            template = YamlTemplate(model_obj=model_obj, dump_kwds=dump_kwds, sub_comments=sub_comments)
            return template.create_yaml_template(write_to=write_file)
        else:
            dump_kwds.setdefault("indent", 4)
            data = cls(**init_kwds).model_dump_json(**dump_kwds)
            data = model_obj.model_dump_json(**dump_kwds)
            if write_file:
                write_file = cls.model_config.get("json_file")
                write_file = self.model_config.get("json_file")
                with open(write_file, "w", encoding="utf-8") as fp:
                    fp.write(data)
            return data
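
Making `create_template_file` an instance method is what lets `chatchat init` write templates that reflect the in-memory edits shown in cli.py: the template is rendered from a live `model_obj` rather than a fresh `cls(**init_kwds)`. A hedged usage sketch (that `write_file=True` resolves to the settings file's default path is an assumption here):

```python
# Hedged sketch: render a YAML template from a mutated settings instance.
from chatchat.settings import Settings

ms = Settings.model_settings
ms.DEFAULT_LLM_MODEL = "qwen2-instruct"  # in-memory edit, as chatchat init does
yaml_text = ms.create_template_file(     # now rendered from the live object,
    write_file=True,                     # not from a fresh cls(**init_kwds)
)
print(yaml_text)
```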
@@ -218,20 +231,21 @@ def _lazy_load_key(settings: BaseSettings):
    for n in ["env_file", "json_file", "yaml_file", "toml_file"]:
        key = None
        if file := settings.model_config.get(n):
            if os.path.isfile(file):
            if os.path.isfile(file) and os.path.getsize(file) > 0:
                key = int(os.path.getmtime(file))
        keys.append(key)
    return tuple(keys)


_T = t.TypeVar("_T", bound=BaseSettings)
_T = t.TypeVar("_T", bound=BaseFileSettings)

@cached(max_size=1, algorithm=CachingAlgorithmFlag.LRU, thread_safe=True, custom_key_maker=_lazy_load_key)
def _cached_settings(settings: _T) -> _T:
'''
the sesstings is cached, and refreshed when configuration files changed
'''
settings.__init__()
if settings.auto_reload:
settings.__init__()
return settings


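
Taken together, `_lazy_load_key` and the new `auto_reload` flag give settings objects a cheap hot-reload: the cache key is the tuple of config-file mtimes, and re-initialization is skipped while `chatchat init` mutates settings in memory. A self-contained sketch of the same pattern (class and method names hypothetical, not the library's API):

```python
import os
import typing as t


class FileBackedConfig:
    """Hypothetical stand-in for BaseFileSettings: re-reads its file only
    when the mtime changes and auto_reload is enabled."""

    def __init__(self, path: str):
        self._path = path
        self._mtime: t.Optional[int] = None
        self.auto_reload = True
        self.raw: str = ""

    def load(self) -> str:
        if not self.auto_reload:           # init-time edits stay untouched
            return self.raw
        # Mirrors _lazy_load_key: missing or empty files never trigger a reload.
        if os.path.isfile(self._path) and os.path.getsize(self._path) > 0:
            mtime = int(os.path.getmtime(self._path))
            if mtime != self._mtime:       # file changed on disk -> re-read
                self._mtime = mtime
                with open(self._path, encoding="utf-8") as fp:
                    self.raw = fp.read()
        return self.raw
```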
34 changes: 17 additions & 17 deletions libs/chatchat-server/chatchat/server/chat/chat.py
@@ -37,23 +37,23 @@ def create_models_from_config(configs, callbacks, stream, max_tokens):
    configs = configs or Settings.model_settings.LLM_MODEL_CONFIG
    models = {}
    prompts = {}
    for model_type, model_configs in configs.items():
        for model_name, params in model_configs.items():
            callbacks = callbacks if params.get("callbacks", False) else None
            # If max_tokens is passed in, use the passed value (API call with an explicit value); otherwise fall back to the initial configuration (UI call, or API call without a value)
            max_tokens_value = max_tokens if max_tokens is not None else params.get("max_tokens", 1000)
            model_instance = get_ChatOpenAI(
                model_name=model_name,
                temperature=params.get("temperature", 0.5),
                max_tokens=max_tokens_value,
                callbacks=callbacks,
                streaming=stream,
                local_wrap=True,
            )
            models[model_type] = model_instance
            prompt_name = params.get("prompt_name", "default")
            prompt_template = get_prompt_template(type=model_type, name=prompt_name)
            prompts[model_type] = prompt_template
    for model_type, params in configs.items():
        model_name = params.get("model", "").strip() or Settings.model_settings.DEFAULT_LLM_MODEL
        callbacks = callbacks if params.get("callbacks", False) else None
        # If max_tokens is passed in, use the passed value (API call with an explicit value); otherwise fall back to the initial configuration (UI call, or API call without a value)
        max_tokens_value = max_tokens if max_tokens is not None else params.get("max_tokens", 1000)
        model_instance = get_ChatOpenAI(
            model_name=model_name,
            temperature=params.get("temperature", 0.5),
            max_tokens=max_tokens_value,
            callbacks=callbacks,
            streaming=stream,
            local_wrap=True,
        )
        models[model_type] = model_instance
        prompt_name = params.get("prompt_name", "default")
        prompt_template = get_prompt_template(type=model_type, name=prompt_name)
        prompts[model_type] = prompt_template
    return models, prompts


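
The rewritten loop implies a flattened `LLM_MODEL_CONFIG`: one params dict per model type, with the model name under a "model" key instead of serving as a nested mapping key. An illustrative (not authoritative) shape:

```python
# Illustrative only; keys besides "model" mirror the defaults read above.
LLM_MODEL_CONFIG = {
    "llm_model": {
        "model": "qwen2-instruct",  # non-empty: used as-is
        "temperature": 0.5,
        "max_tokens": 1000,
        "callbacks": True,
        "prompt_name": "default",
    },
    "action_model": {
        "model": "",                # empty -> Settings.model_settings.DEFAULT_LLM_MODEL
        "temperature": 0.01,
        "prompt_name": "default",
    },
}
```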
2 changes: 1 addition & 1 deletion libs/chatchat-server/chatchat/server/chat/file_chat.py
@@ -76,7 +76,7 @@ def parse_file(file: UploadFile) -> dict:
def upload_temp_docs(
    files: List[UploadFile] = File(..., description="Uploaded files; multiple files supported"),
    prev_id: str = Form(None, description="Previous knowledge base ID"),
    chunk_size: int = Form(Settings.model_settings.LLM_MODEL_CONFIG, description="Maximum length of a single text chunk in the knowledge base"),
    chunk_size: int = Form(Settings.kb_settings.CHUNK_SIZE, description="Maximum length of a single text chunk in the knowledge base"),
    chunk_overlap: int = Form(Settings.kb_settings.OVERLAP_SIZE, description="Overlap length between adjacent text chunks in the knowledge base"),
    zh_title_enhance: bool = Form(Settings.kb_settings.ZH_TITLE_ENHANCE, description="Whether to enable Chinese title enhancement"),
) -> BaseResponse:
@@ -25,11 +25,12 @@ class ESKBService(KBService):
    def do_init(self):
        self.kb_path = self.get_kb_path(self.kb_name)
        self.index_name = os.path.split(self.kb_path)[-1]
        self.IP = Settings.kb_settings.kbs_config[self.vs_type()]["host"]
        self.PORT = Settings.kb_settings.kbs_config[self.vs_type()]["port"]
        self.user = Settings.kb_settings.kbs_config[self.vs_type()].get("user", "")
        self.password = Settings.kb_settings.kbs_config[self.vs_type()].get("password", "")
        self.dims_length = Settings.kb_settings.kbs_config[self.vs_type()].get("dims_length", None)
        kb_config = Settings.kb_settings.kbs_config[self.vs_type()]
        self.IP = kb_config["host"]
        self.PORT = kb_config["port"]
        self.user = kb_config.get("user", "")
        self.password = kb_config.get("password", "")
        self.dims_length = kb_config.get("dims_length", None)
        self.embeddings_model = get_Embeddings(self.embed_model)
        try:
            # ES Python client connection (connection only)
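
The refactor only hoists the repeated `kbs_config[self.vs_type()]` lookup into a local variable. For reference, an illustrative (not authoritative) `es` entry carrying the keys `do_init` reads:

```python
# Illustrative values only; the real entry lives in the kb settings file.
kbs_config = {
    "es": {
        "host": "127.0.0.1",
        "port": "9200",
        "user": "",           # optional, defaults to ""
        "password": "",       # optional, defaults to ""
        "dims_length": 1024,  # optional, defaults to None
    },
}
```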
6 changes: 3 additions & 3 deletions libs/chatchat-server/chatchat/server/knowledge_base/utils.py
@@ -350,7 +350,7 @@ def docs2texts(
        docs: List[Document] = None,
        zh_title_enhance: bool = Settings.kb_settings.ZH_TITLE_ENHANCE,
        refresh: bool = False,
        chunk_size: int = Settings.model_settings.DEFAULT_EMBEDDING_MODEL,
        chunk_size: int = Settings.kb_settings.CHUNK_SIZE,
        chunk_overlap: int = Settings.kb_settings.OVERLAP_SIZE,
        text_splitter: TextSplitter = None,
    ):
@@ -382,7 +382,7 @@ def file2text(
        self,
        zh_title_enhance: bool = Settings.kb_settings.ZH_TITLE_ENHANCE,
        refresh: bool = False,
        chunk_size: int = Settings.model_settings.DEFAULT_EMBEDDING_MODEL,
        chunk_size: int = Settings.kb_settings.CHUNK_SIZE,
        chunk_overlap: int = Settings.kb_settings.OVERLAP_SIZE,
        text_splitter: TextSplitter = None,
    ):
@@ -421,7 +421,7 @@ def files2docs_in_thread_file2docs(

def files2docs_in_thread(
    files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
    chunk_size: int = Settings.model_settings.DEFAULT_EMBEDDING_MODEL,
    chunk_size: int = Settings.kb_settings.CHUNK_SIZE,
    chunk_overlap: int = Settings.kb_settings.OVERLAP_SIZE,
    zh_title_enhance: bool = Settings.kb_settings.ZH_TITLE_ENHANCE,
) -> Generator:
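
These signature changes, like the one in file_chat.py above, fix the same class of bug: the old defaults plugged unrelated settings (an embedding model name, or a whole config dict) into `chunk_size`, handing the text splitter a non-integer. A minimal sketch (the 750 is illustrative, not the documented CHUNK_SIZE value):

```python
# Old default: a model name string where an int chunk size was expected.
chunk_size = "bge-large-zh-v1.5"  # any splitter doing integer math on this fails
# New default: an actual size read from kb settings.
chunk_size = 750                  # illustrative stand-in for Settings.kb_settings.CHUNK_SIZE
```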