# main.py (forked from langchain-ai/chat-langchain)
"""Main entrypoint for the app."""
import os
from typing import Optional
import weaviate
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import MessagesPlaceholder
from langchain.schema.messages import HumanMessage, AIMessage, SystemMessage
from langchain.schema.runnable import Runnable, RunnableConfig
from langchain.vectorstores import Weaviate
from langsmith import Client
from threading import Thread
from queue import Queue, Empty
from collections.abc import Generator
from langchain.agents import (
Tool,
AgentExecutor,
)
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
import pickle
from langchain.callbacks.base import BaseCallbackHandler
from constants import WEAVIATE_SOURCES_INDEX_NAME, WEAVIATE_DOCS_INDEX_NAME
client = Client()

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)

run_collector = RunCollectorCallbackHandler()
runnable_config = RunnableConfig(callbacks=[run_collector])

# Per-conversation state shared across the endpoints below.
run_id = None
trace_url = None
feedback_recorded = False

WEAVIATE_URL = os.environ["WEAVIATE_URL"]
WEAVIATE_API_KEY = os.environ["WEAVIATE_API_KEY"]
def search(inp: str, index_name: str, callbacks=None) -> list:
    """Return the top documents for a query from the given Weaviate index."""
    weaviate_conn = weaviate.Client(
        url=WEAVIATE_URL,
        auth_client_secret=weaviate.AuthApiKey(api_key=WEAVIATE_API_KEY),
    )
    vectorstore = Weaviate(
        client=weaviate_conn,
        index_name=index_name,
        text_key="text",
        embedding=OpenAIEmbeddings(chunk_size=200),
        by_text=False,
        attributes=["source"] if index_name != WEAVIATE_SOURCES_INDEX_NAME else None,
    )
    retriever = vectorstore.as_retriever(search_kwargs=dict(k=3), callbacks=callbacks)
    return retriever.get_relevant_documents(inp, callbacks=callbacks)
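# A minimal usage sketch (hypothetical query; assumes the Weaviate indexes and
# OPENAI_API_KEY are already configured):
#   docs = search("How do I stream agent output?", WEAVIATE_DOCS_INDEX_NAME)
#   sources = [d.metadata["source"] for d in docs]  # up to k=3 matches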
# Pre-transformed documents (page_content plus a "source" metadata key),
# presumably produced by the ingestion pipeline; search_everything() joins
# retrieved source names against this list.
with open("agent_all_transformed.pkl", "rb") as f:
    all_texts = pickle.load(f)
def search_everything(inp: str, callbacks: Optional[Any] = None) -> list[str]:
    """Search the docs index, then expand matched sources into full documents."""
    docs_references = search(inp, WEAVIATE_DOCS_INDEX_NAME, callbacks=callbacks)
    # repo_references = search(inp, "WEAVIATE_REPO_INDEX_NAME", callbacks=callbacks)
    all_references = docs_references
    all_references_sources = [r for r in all_references if r.metadata["source"]]

    sources = search(inp, WEAVIATE_SOURCES_INDEX_NAME, callbacks=callbacks)
    source_names = {source.page_content for source in sources}
    sources_docs = [doc for doc in all_texts if doc.metadata["source"] in source_names]

    combined_sources = sources_docs + all_references_sources
    return [doc.page_content for doc in combined_sources]
def get_tools():
    langchain_tool = Tool(
        name="Documentation",
        func=search_everything,
        description=(
            "useful for when you need to refer to LangChain's documentation, "
            "for both API reference and codebase"
        ),
    )
    return [langchain_tool]
def get_agent(llm, chat_history: Optional[list] = None):
    system_message = SystemMessage(
        content=(
            "You are an expert developer who is tasked with scouring documentation to answer questions about LangChain. "
            "Answer the following question as best you can. "
            "Be inclined to include CORRECT Python code snippets if relevant to the question. If you can't find the answer, DO NOT hallucinate. Just say you don't know. "
            "You have access to a LangChain knowledge bank retriever tool for your answer but know NOTHING about LangChain otherwise. "
            "Always provide articulate detail to your action input. "
            "You should always first check your search tool for information on the concepts in the question. "
            "For example, given the following input question:\n"
            "-----START OF EXAMPLE INPUT QUESTION-----\n"
            "What is the transform() method for runnables?\n"
            "-----END OF EXAMPLE INPUT QUESTION-----\n"
            "Your research flow should be:\n"
            "1. Query your search tool for information on 'Transform() method' to get as much context as you can about it.\n"
            "2. Then, query your search tool for information on 'Runnables' to get as much context as you can about it.\n"
            "3. Answer the question with the context you have gathered.\n"
            "For another example, given the following input question:\n"
            "-----START OF EXAMPLE INPUT QUESTION-----\n"
            "How can I use vLLM to run my own locally hosted model?\n"
            "-----END OF EXAMPLE INPUT QUESTION-----\n"
            "Your research flow should be:\n"
            "1. Query your search tool for information on 'vLLM' to get as much context as you can about it.\n"
            "2. Answer the question as you now have enough context."
        )
    )
    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=system_message,
        extra_prompt_messages=[MessagesPlaceholder(variable_name="chat_history")],
    )
    memory = AgentTokenBufferMemory(
        memory_key="chat_history", llm=llm, max_token_limit=2000
    )
    # Seed the memory with any prior turns; default to an empty history so the
    # loop is safe when no history is passed.
    for msg in chat_history or []:
        if "question" in msg:
            memory.chat_memory.add_user_message(str(msg.pop("question")))
        if "result" in msg:
            memory.chat_memory.add_ai_message(str(msg.pop("result")))
    tools = get_tools()
    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
    return AgentExecutor(
        agent=agent,
        tools=tools,
        memory=memory,
        verbose=True,
        return_intermediate_steps=True,
    )
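# A minimal usage sketch outside the request flow (hypothetical history values):
#   llm = ChatOpenAI(model="gpt-3.5-turbo-16k", temperature=0)
#   executor = get_agent(llm, chat_history=[{"question": "hi", "result": "hello"}])
#   executor.invoke({"input": "What is a Runnable?"})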
class QueueCallback(BaseCallbackHandler):
    """Callback handler for streaming LLM responses to a queue.

    Adapted from https://gist.github.com/mortymike/70711b028311681e5f3c6511031d5d43
    """

    def __init__(self, q):
        self.q = q

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.q.put(token)

    def on_llm_end(self, *args, **kwargs: Any) -> None:
        # The return value is ignored by the callback machinery; this just
        # reports whether the consumer has drained the queue.
        return self.q.empty()
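# The handler is the producer half of a simple handoff: the LLM worker thread
# pushes tokens into the queue and the /chat generator below drains them.
# Roughly (hypothetical standalone sketch):
#   q = Queue()
#   llm = ChatOpenAI(streaming=True, callbacks=[QueueCallback(q)])
#   # ...invoke llm in a Thread, then q.get() tokens on the consumer side.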
@app.post("/chat")
async def chat_endpoint(request: Request):
global run_id, feedback_recorded, trace_url
run_id = None
trace_url = None
feedback_recorded = False
run_collector.traced_runs = []
data = await request.json()
question = data.get("message")
chat_history = data.get("history", [])
conversation_id = data.get("conversation_id")
print("Recieved question: ", question)
def stream() -> Generator:
global run_id, trace_url, feedback_recorded
q = Queue()
job_done = object()
llm = ChatOpenAI(model="gpt-3.5-turbo-16k", streaming=True, temperature=0, callbacks=[QueueCallback(q)])
def task():
agent = get_agent(llm, chat_history)
agent.invoke({"input": question, "chat_history": chat_history}, config=runnable_config)
q.put(job_done)
t = Thread(target=task)
t.start()
content = ""
while True:
try:
next_token = q.get(True, timeout=1)
if next_token is job_done:
break
content += next_token
yield next_token
except Empty:
continue
if not run_id and run_collector.traced_runs:
run = run_collector.traced_runs[0]
run_id = run.id
return StreamingResponse(stream())
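# Example request (assumes the server is running locally on port 8080, as in
# the __main__ block below; curl's -N flag disables buffering so tokens stream):
#   curl -N -X POST http://localhost:8080/chat \
#     -H "Content-Type: application/json" \
#     -d '{"message": "What is a Runnable?", "history": []}'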
@app.post("/feedback")
async def send_feedback(request: Request):
global run_id, feedback_recorded
if feedback_recorded or run_id is None:
return {"result": "Feedback already recorded or no chat session found", "code": 400}
data = await request.json()
score = data.get("score")
client.create_feedback(run_id, "user_score", score=score)
feedback_recorded = True
return {"result": "posted feedback successfully", "code": 200}
@app.post("/get_trace")
async def get_trace(request: Request):
global run_id, trace_url
if trace_url is None and run_id is not None:
trace_url = client.share_run(run_id)
if run_id is None:
return {"result": "No chat session found", "code": 400}
return trace_url if trace_url else {"result": "Trace URL not found", "code": 400}
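# Example request (returns a shareable LangSmith trace URL for the last run):
#   curl -X POST http://localhost:8080/get_trace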
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8080)