Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
yaojin3616 committed Dec 18, 2023
1 parent 28d7682 commit 7818e3e
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 33 deletions.
33 changes: 21 additions & 12 deletions src/backend/bisheng/api/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from bisheng.api.v1.schemas import StreamData
from bisheng.database.base import get_session
from bisheng.database.models.role_access import AccessType, RoleAccess
Expand All @@ -22,8 +21,8 @@ def remove_api_keys(flow: dict):
node_data = node.get('data').get('node')
template = node_data.get('template')
for value in template.values():
if (isinstance(value, dict) and has_api_terms(value['name']) and
value.get('password')):
if (isinstance(value, dict) and has_api_terms(value['name'])
and value.get('password')):
value['value'] = None

return flow
Expand All @@ -34,7 +33,8 @@ def build_input_keys_response(langchain_object, artifacts):

input_keys_response = {
'input_keys': {
key: '' for key in langchain_object.input_keys
key: ''
for key in langchain_object.input_keys
},
'memory_keys': [],
'handle_keys': artifacts.get('handle_keys', []),
Expand Down Expand Up @@ -72,7 +72,7 @@ def build_flow(graph_data: dict,
# Some error could happen when building the graph
graph = Graph.from_payload(graph_data)
except Exception as exc:
logger.exception(exc)
logger.error(exc)
error_message = str(exc)
yield str(StreamData(event='error', data={'error': error_message}))
return
Expand Down Expand Up @@ -224,7 +224,10 @@ def access_check(payload: dict, owner_user_id: int, target_id: int, type: Access
return True


def get_L2_param_from_flow(flow_data: dict, flow_id: str,):
def get_L2_param_from_flow(
flow_data: dict,
flow_id: str,
):
graph = Graph.from_payload(flow_data)
node_id = []
variable_ids = []
Expand All @@ -239,8 +242,10 @@ def get_L2_param_from_flow(flow_data: dict, flow_id: str,):
session: Session = next(get_session())
db_variables = session.exec(select(Variable).where(Variable.flow_id == flow_id)).all()

old_file_ids = {variable.node_id: variable
for variable in db_variables if variable.value_type == 3}
old_file_ids = {
variable.node_id: variable
for variable in db_variables if variable.value_type == 3
}
update = []
delete_node_ids = []
try:
Expand All @@ -252,12 +257,16 @@ def get_L2_param_from_flow(flow_data: dict, flow_id: str,):
old_file_ids.pop(id)
else:
# file type
db_new_var = Variable(flow_id=flow_id, node_id=id,
variable_name=file_name[index], value_type=3)
db_new_var = Variable(flow_id=flow_id,
node_id=id,
variable_name=file_name[index],
value_type=3)
update.append(db_new_var)
# delete variables that were not already deleted by the edit
old_variable_ids = {variable.node_id
for variable in db_variables if variable.value_type != 3}
old_variable_ids = {
variable.node_id
for variable in db_variables if variable.value_type != 3
}

if old_file_ids:
delete_node_ids.extend(list(old_file_ids.keys()))
Expand Down
2 changes: 1 addition & 1 deletion src/backend/bisheng/api/v1/knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def addEmbedding(collection_name, knowledge_id: int, model: str, chunk_size: int
session.refresh(db_file)
callback_obj = db_file.copy()
except Exception as e:
logger.exception(e)
logger.error(e)
db_file = session.get(KnowledgeFile, knowledge_file.id)
setattr(db_file, 'status', 3)
setattr(db_file, 'remark', str(e)[:500])
Expand Down
1 change: 1 addition & 0 deletions src/backend/bisheng/api/v1/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def get_original_file(*, message_id: int, keys: str, session: Session = Depends(
chunk_res['score'] = round(match_score(chunk.chunk, keywords),
2) if len(keywords) > 0 else 0
chunk_res['file_id'] = chunk.file_id
chunk_res['source'] = file.file_name

result.append(chunk_res)

Expand Down
6 changes: 3 additions & 3 deletions src/backend/bisheng/interface/initialize/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def instantiate_chains(node_type, class_object: Type[Chain], params: Dict, id_di
params['get_chat_history'] = str
params['combine_docs_chain_kwargs'] = {
'prompt': params.pop('combine_docs_chain_kwargs', None),
'source_document': params.pop('source_document', None)
'document_prompt': params.pop('document_prompt', None)
}
params['combine_docs_chain_kwargs'] = {
k: v
Expand Down Expand Up @@ -440,13 +440,13 @@ def instantiate_embedding(class_object, params: Dict):


def instantiate_vectorstore(class_object: Type[VectorStore], params: Dict):
search_kwargs = params.pop('search_kwargs', {})
user_name = params.pop('user_name', '')
search_kwargs = params.pop('search_kwargs', {})
if 'documents' not in params:
params['documents'] = []

if initializer := vecstore_initializer.get(class_object.__name__):
vecstore = initializer(class_object, params)
vecstore = initializer(class_object, params, search_kwargs)
else:
if 'texts' in params:
params['documents'] = params.pop('texts')
Expand Down
22 changes: 11 additions & 11 deletions src/backend/bisheng/interface/initialize/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def docs_in_params(params: dict) -> bool:
and params['texts'])


def initialize_mongodb(class_object: Type[MongoDBAtlasVectorSearch], params: dict):
def initialize_mongodb(class_object: Type[MongoDBAtlasVectorSearch], params: dict, search: dict):
"""Initialize mongodb and return the class object"""

MONGODB_ATLAS_CLUSTER_URI = params.pop('mongodb_atlas_cluster_uri')
Expand Down Expand Up @@ -59,7 +59,7 @@ def initialize_mongodb(class_object: Type[MongoDBAtlasVectorSearch], params: dic
return class_object.from_documents(**params)


def initialize_supabase(class_object: Type[SupabaseVectorStore], params: dict):
def initialize_supabase(class_object: Type[SupabaseVectorStore], params: dict, search: dict):
"""Initialize supabase and return the class object"""
from supabase.client import Client, create_client

Expand All @@ -83,7 +83,7 @@ def initialize_supabase(class_object: Type[SupabaseVectorStore], params: dict):
return class_object.from_documents(client=supabase, **params)


def initialize_weaviate(class_object: Type[Weaviate], params: dict):
def initialize_weaviate(class_object: Type[Weaviate], params: dict, search: dict):
"""Initialize weaviate and return the class object"""
if not docs_in_params(params):
import weaviate # type: ignore
Expand All @@ -109,7 +109,7 @@ def initialize_weaviate(class_object: Type[Weaviate], params: dict):
return class_object.from_documents(**params)


def initialize_faiss(class_object: Type[FAISS], params: dict):
def initialize_faiss(class_object: Type[FAISS], params: dict, search: dict):
"""Initialize faiss and return the class object"""

if not docs_in_params(params):
Expand All @@ -122,7 +122,7 @@ def initialize_faiss(class_object: Type[FAISS], params: dict):
return faiss_index


def initialize_pinecone(class_object: Type[Pinecone], params: dict):
def initialize_pinecone(class_object: Type[Pinecone], params: dict, search: dict):
"""Initialize pinecone and return the class object"""

import pinecone # type: ignore
Expand Down Expand Up @@ -163,7 +163,7 @@ def initialize_pinecone(class_object: Type[Pinecone], params: dict):
return class_object.from_documents(**params)


def initialize_chroma(class_object: Type[Chroma], params: dict):
def initialize_chroma(class_object: Type[Chroma], params: dict, search: dict):
"""Initialize a ChromaDB object from the params"""
persist = params.pop('persist', False)
if not docs_in_params(params):
Expand All @@ -186,7 +186,7 @@ def initialize_chroma(class_object: Type[Chroma], params: dict):
return chromadb


def initialize_qdrant(class_object: Type[Qdrant], params: dict):
def initialize_qdrant(class_object: Type[Qdrant], params: dict, search: dict):
if not docs_in_params(params):
if 'location' not in params and 'api_key' not in params:
raise ValueError('Location and API key must be provided in the params')
Expand All @@ -207,7 +207,7 @@ def initialize_qdrant(class_object: Type[Qdrant], params: dict):
return class_object.from_documents(**params)


def initial_milvus(class_object: Type[Milvus], params: dict):
def initial_milvus(class_object: Type[Milvus], params: dict, search_kwargs: dict):
if not params['connection_args'] and settings.get_knowledge().get('vectorstores').get('Milvus'):
params['connection_args'] = settings.get_knowledge().get('vectorstores').get('Milvus').get(
'connection_args')
Expand All @@ -227,12 +227,12 @@ def initial_milvus(class_object: Type[Milvus], params: dict):
else:
embedding = HostEmbeddings(**model_param)
params['embedding'] = embedding
if knowledge.collection_name.startswith('partiton'):
params['partition_key'] = knowledge.id
if knowledge.collection_name.startswith('partition'):
search_kwargs.update({'partition_key': knowledge.id})
return class_object.from_documents(**params)


def initial_elastic(class_object: Type[ElasticKeywordsSearch], params: dict):
def initial_elastic(class_object: Type[ElasticKeywordsSearch], params: dict, search: dict):
if not params['elasticsearch_url'] and settings.get_knowledge().get('vectorstores').get(
'ElasticKeywordsSearch'):
params['elasticsearch_url'] = settings.get_knowledge().get('vectorstores').get(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -502,10 +502,7 @@ def add_texts(
# Insert into the collection.
try:
res: Collection
res = self.col.insert(insert_list,
partition_name=self._partition_field,
timeout=timeout,
**kwargs)
res = self.col.insert(insert_list, timeout=timeout, **kwargs)
pks.extend(res.primary_keys)
except MilvusException as e:
logger.error('Failed to insert batch starting at entity: %s/%s', i, total_count)
Expand Down Expand Up @@ -667,9 +664,9 @@ def similarity_search_with_score_by_vector(
if 'partition_key' in kwargs:
# add partition
if expr:
expr = f"{expr} and {self._partition_field}==${kwargs['partition_key']}"
expr = f"{expr} and {self._partition_field}==\"{kwargs['partition_key']}\""
else:
expr = f"{self._partition_field}==${kwargs['partition_key']}"
expr = f"{self._partition_field}==\"{kwargs['partition_key']}\""

# Perform the search.
res = self.col.search(
Expand Down

0 comments on commit 7818e3e

Please sign in to comment.