Skip to content

Commit

Permalink
Pass source docs in directly to cohere (langchain-ai#274)
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan authored Feb 21, 2024
1 parent 0abf4b0 commit 24ffeb0
Showing 1 changed file with 63 additions and 29 deletions.
92 changes: 63 additions & 29 deletions chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from langchain_community.chat_models import ChatAnthropic, ChatCohere, ChatFireworks
from langchain_community.vectorstores import Weaviate
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
Expand All @@ -23,7 +23,9 @@
Runnable,
RunnableBranch,
RunnableLambda,
RunnableMap,
RunnablePassthrough,
RunnableSequence,
chain,
)
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
Expand Down Expand Up @@ -65,6 +67,32 @@
user.\
"""

COHERE_RESPONSE_TEMPLATE = """\
You are an expert programmer and problem-solver, tasked with answering any question \
about Langchain.
Generate a comprehensive and informative answer of 80 words or less for the \
given question based solely on the provided search results (URL and content). You must \
only use information from the provided search results. Use an unbiased and \
journalistic tone. Combine search results together into a coherent answer. Do not \
repeat text. Cite search results using [${{number}}] notation. Only cite the most \
relevant results that answer the question accurately. Place these citations at the end \
of the sentence or paragraph that reference them - do not put them all at the end. If \
different results refer to different entities within the same name, write separate \
answers for each entity.
You should use bullet points in your answer for readability. Put citations where they apply
rather than putting them all at the end.
If there is nothing in the context relevant to the question at hand, just say "Hmm, \
I'm not sure." Don't try to make up an answer.
REMEMBER: If there is no relevant information within the context, just say "Hmm, I'm \
not sure." Don't try to make up an answer. Anything between the preceding 'context' \
html blocks is retrieved from a knowledge bank, not part of the conversation with the \
user.\
"""

REPHRASE_TEMPLATE = """\
Given the following conversation and a follow up question, rephrase the follow up \
question to be a standalone question.
Expand Down Expand Up @@ -114,7 +142,7 @@ def get_retriever() -> BaseRetriever:


def create_retriever_chain(
llm: BaseLanguageModel, retriever: BaseRetriever
llm: LanguageModelLike, retriever: BaseRetriever
) -> Runnable:
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(REPHRASE_TEMPLATE)
condense_question_chain = (
Expand Down Expand Up @@ -158,42 +186,51 @@ def serialize_history(request: ChatRequest):
return converted_chat_history


def create_chain(
llm: BaseLanguageModel,
retriever: BaseRetriever,
) -> Runnable:
def create_chain(llm: LanguageModelLike, retriever: BaseRetriever) -> Runnable:
retriever_chain = create_retriever_chain(
llm,
retriever,
).with_config(run_name="FindDocs")
_context = RunnableMap(
{
"context": retriever_chain | format_docs,
"question": itemgetter("question"),
"chat_history": itemgetter("chat_history"),
}
).with_config(run_name="RetrieveDocs")
context = (
RunnablePassthrough.assign(docs=retriever_chain)
.assign(context=lambda x: format_docs(x["docs"]))
.with_config(run_name="RetrieveDocs")
)
prompt = ChatPromptTemplate.from_messages(
[
("system", RESPONSE_TEMPLATE),
MessagesPlaceholder(variable_name="chat_history"),
("human", "{question}"),
]
)
default_response_synthesizer = prompt | llm

response_synthesizer = (prompt | llm | StrOutputParser()).with_config(
run_name="GenerateResponse",
cohere_prompt = ChatPromptTemplate.from_messages(
[
("system", COHERE_RESPONSE_TEMPLATE),
MessagesPlaceholder(variable_name="chat_history"),
("human", "{question}"),
]
)

@chain
def cohere_response_synthesizer(input: dict) -> RunnableSequence:
return cohere_prompt | llm.bind(source_documents=input["docs"])

response_synthesizer = (
default_response_synthesizer.configurable_alternatives(
ConfigurableField("llm"),
default_key="openai_gpt_3_5_turbo",
anthropic_claude_2_1=default_response_synthesizer,
fireworks_mixtral=default_response_synthesizer,
google_gemini_pro=default_response_synthesizer,
cohere_command=cohere_response_synthesizer,
)
| StrOutputParser()
).with_config(run_name="GenerateResponse")
return (
{
"question": RunnableLambda(itemgetter("question")).with_config(
run_name="Itemgetter:question"
),
"chat_history": RunnableLambda(serialize_history).with_config(
run_name="SerializeHistory"
),
}
| _context
RunnablePassthrough.assign(chat_history=serialize_history)
| context
| response_synthesizer
)

Expand Down Expand Up @@ -234,7 +271,4 @@ def create_chain(
)

retriever = get_retriever()
answer_chain = create_chain(
llm,
retriever,
)
answer_chain = create_chain(llm, retriever)

0 comments on commit 24ffeb0

Please sign in to comment.