Skip to content

Commit

Permalink
fixed streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
mcantillon21 committed Aug 15, 2023
1 parent 94b8621 commit aff17b5
Showing 1 changed file with 25 additions and 22 deletions.
47 changes: 25 additions & 22 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
def create_chain(
retriever: BaseRetriever,
model_provider: Union[Literal["openai"], Literal["anthropic"]],
chat_history: Optional[list] = None,
model: Optional[str] = None,
temperature: float = 0.0,
) -> Runnable:
Expand All @@ -65,7 +64,7 @@ def create_chain(
Conversation History:
<conversation_history>
{history}
{history}
<conversation_history/>
Answer the user's question to the best of your ability: {question}
Expand All @@ -74,21 +73,13 @@ def create_chain(
prompt = PromptTemplate(
input_variables=["history", "context", "question"], template=_template
)
memory = ConversationBufferMemory(input_key="question", memory_key="history")
chat_history_ = chat_history or []
for message in chat_history_:
memory.save_context(
{"question": message["question"]}, {"result": message["result"]}
)

qa_chain = RetrievalQA.from_chain_type(
model,
chain_type="stuff",
retriever=retriever,
return_source_documents=True,
chain_type_kwargs={"prompt": prompt, "memory": memory},

chain = (
prompt
| model
)
return qa_chain

return chain


def _get_retriever():
Expand All @@ -100,7 +91,7 @@ def _get_retriever():
url=WEAVIATE_URL,
auth_client_secret=weaviate.AuthApiKey(api_key=WEAVIATE_API_KEY),
)
print(client.query.aggregate("LangChain_idx").with_meta_count().do())
# print(client.query.aggregate("LangChain_idx").with_meta_count().do())
weaviate_client = Weaviate(
client=client,
index_name="LangChain_idx",
Expand All @@ -111,6 +102,13 @@ def _get_retriever():
)
return weaviate_client.as_retriever(search_kwargs=dict(k=10))

def _process_chat_history(chat_history):
for chat in chat_history:
if 'question' in chat:
chat['HumanChatMessage'] = chat.pop('question')
if 'result' in chat:
chat['AIChatMessage'] = chat.pop('result')
return chat_history

@app.post("/chat")
async def chat_endpoint(request: Request):
Expand All @@ -125,18 +123,23 @@ async def chat_endpoint(request: Request):
chat_history = data.get("history", [])

retriever = _get_retriever()
source_docs = retriever.invoke(question) # opportunity to return source documents
context = [doc.page_content for doc in source_docs]

chat_history = _process_chat_history(chat_history)

qa_chain = create_chain(
retriever=retriever, model_provider=model_type, chat_history=chat_history
retriever=retriever, model_provider=model_type
)
print("Recieved question: ", question)

async def stream():
result = ""
try:
for s in qa_chain.stream(question, config=runnable_config):
result += s["result"]
yield s["result"]
await asyncio.sleep(0)
async for s in qa_chain.astream({"context": context, "question": question, "history": chat_history}, config=runnable_config):
print(s.content, end="", flush=True)
result += s.content
yield s.content

except Exception as e:
logging.error(e)
Expand Down

0 comments on commit aff17b5

Please sign in to comment.