Format (langchain-ai#186)
hinthornw authored Oct 1, 2023
1 parent 2e519f8 commit bf1ec0c
Showing 19 changed files with 387 additions and 200 deletions.
3 changes: 3 additions & 0 deletions chat-langchain/.prettierrc
@@ -0,0 +1,3 @@
+{
+  "endOfLine": "lf"
+}
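A note on the config above: setting Prettier's endOfLine option to "lf" normalizes every formatted file to Unix-style line endings, so contributors on Windows don't reintroduce CRLF churn in later diffs. The rest of this commit appears to follow Prettier's defaults; the trailing commas after function arguments are consistent with Prettier 3.x, though the exact version isn't visible here.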
145 changes: 97 additions & 48 deletions chat-langchain/app/api/chat/route.ts
@@ -10,7 +10,11 @@ import { RunnableSequence, RunnableMap } from "langchain/schema/runnable";
 import { HumanMessage, AIMessage, BaseMessage } from "langchain/schema";
 import { ChatOpenAI } from "langchain/chat_models/openai";
 import { StringOutputParser } from "langchain/schema/output_parser";
-import { PromptTemplate, ChatPromptTemplate, MessagesPlaceholder } from "langchain/prompts";
+import {
+  PromptTemplate,
+  ChatPromptTemplate,
+  MessagesPlaceholder,
+} from "langchain/prompts";
 
 import weaviate from "weaviate-ts-client";
 import { WeaviateStore } from "langchain/vectorstores/weaviate";
@@ -39,72 +43,93 @@ const getRetriever = async () => {
   const client = weaviate.client({
     scheme: "https",
     host: process.env.WEAVIATE_HOST!,
-    apiKey: new weaviate.ApiKey(
-      process.env.WEAVIATE_API_KEY!
-    ),
-  });
-  const vectorstore = await WeaviateStore.fromExistingIndex(new OpenAIEmbeddings({}), {
-    client,
-    indexName: process.env.WEAVIATE_INDEX_NAME!,
-    textKey: "text",
-    metadataKeys: ["source", "title"],
+    apiKey: new weaviate.ApiKey(process.env.WEAVIATE_API_KEY!),
   });
+  const vectorstore = await WeaviateStore.fromExistingIndex(
+    new OpenAIEmbeddings({}),
+    {
+      client,
+      indexName: process.env.WEAVIATE_INDEX_NAME!,
+      textKey: "text",
+      metadataKeys: ["source", "title"],
+    },
+  );
   return vectorstore.asRetriever({ k: 6 });
 };
 
-const createRetrieverChain = (llm: BaseLanguageModel, retriever: BaseRetriever, useChatHistory: boolean) => {
+const createRetrieverChain = (
+  llm: BaseLanguageModel,
+  retriever: BaseRetriever,
+  useChatHistory: boolean,
+) => {
   // Small speed/accuracy optimization: no need to rephrase the first question
   // since there shouldn't be any meta-references to prior chat history
   if (!useChatHistory) {
-    return RunnableSequence.from([
-      ({ question }) => question,
-      retriever
-    ]);
+    return RunnableSequence.from([({ question }) => question, retriever]);
   } else {
-    const CONDENSE_QUESTION_PROMPT = PromptTemplate.fromTemplate(REPHRASE_TEMPLATE);
+    const CONDENSE_QUESTION_PROMPT =
+      PromptTemplate.fromTemplate(REPHRASE_TEMPLATE);
     const condenseQuestionChain = RunnableSequence.from([
       CONDENSE_QUESTION_PROMPT,
       llm,
-      new StringOutputParser()
+      new StringOutputParser(),
     ]).withConfig({
-      tags: ["CondenseQuestion"]
+      tags: ["CondenseQuestion"],
     });
     return condenseQuestionChain.pipe(retriever);
   }
 };
 
 const formatDocs = (docs: Document[]) => {
-  return docs.map((doc, i) => `<doc id='${i}'>${doc.pageContent}</doc>`).join("\n");
+  return docs
+    .map((doc, i) => `<doc id='${i}'>${doc.pageContent}</doc>`)
+    .join("\n");
 };
 
 const formatChatHistoryAsString = (history: BaseMessage[]) => {
-  return history.map((message) => `${message._getType()}: ${message.content}`).join('\n');
-}
+  return history
+    .map((message) => `${message._getType()}: ${message.content}`)
+    .join("\n");
+};
 
-const createChain = (llm: BaseLanguageModel, retriever: BaseRetriever, useChatHistory: boolean) => {
-  const retrieverChain = createRetrieverChain(llm, retriever, useChatHistory).withConfig({ tags: ["FindDocs"] });
+const createChain = (
+  llm: BaseLanguageModel,
+  retriever: BaseRetriever,
+  useChatHistory: boolean,
+) => {
+  const retrieverChain = createRetrieverChain(
+    llm,
+    retriever,
+    useChatHistory,
+  ).withConfig({ tags: ["FindDocs"] });
   const context = new RunnableMap({
     steps: {
       context: RunnableSequence.from([
-        ({question, chat_history}) => ({question, chat_history: formatChatHistoryAsString(chat_history)}),
+        ({ question, chat_history }) => ({
+          question,
+          chat_history: formatChatHistoryAsString(chat_history),
+        }),
         retrieverChain,
-        formatDocs
+        formatDocs,
       ]),
       question: ({ question }) => question,
-      chat_history: ({ chat_history }) => chat_history
-    }
+      chat_history: ({ chat_history }) => chat_history,
+    },
   }).withConfig({ tags: ["RetrieveDocs"] });
   const prompt = ChatPromptTemplate.fromMessages([
     ["system", RESPONSE_TEMPLATE],
     new MessagesPlaceholder("chat_history"),
     ["human", "{question}"],
   ]);
 
-  const responseSynthesizerChain = prompt.pipe(llm).pipe(new StringOutputParser()).withConfig({
-    tags: ["GenerateResponse"],
-  });
+  const responseSynthesizerChain = prompt
+    .pipe(llm)
+    .pipe(new StringOutputParser())
+    .withConfig({
+      tags: ["GenerateResponse"],
+    });
   return context.pipe(responseSynthesizerChain);
-}
+};
 
 export async function POST(req: NextRequest) {
   try {
@@ -114,15 +139,22 @@ export async function POST(req: NextRequest) {
     const conversationId = body.conversation_id;
 
     if (question === undefined || typeof question !== "string") {
-      return NextResponse.json({ error: `Invalid "message" parameter.` }, { status: 400 });
+      return NextResponse.json(
+        { error: `Invalid "message" parameter.` },
+        { status: 400 },
+      );
     }
 
     const convertedChatHistory = [];
     for (const historyMessage of chatHistory) {
       if (historyMessage.human) {
-        convertedChatHistory.push(new HumanMessage({ content: historyMessage.human }));
+        convertedChatHistory.push(
+          new HumanMessage({ content: historyMessage.human }),
+        );
       } else if (historyMessage.ai) {
-        convertedChatHistory.push(new AIMessage({ content: historyMessage.ai }));
+        convertedChatHistory.push(
+          new AIMessage({ content: historyMessage.ai }),
+        );
       }
     }
 
@@ -132,7 +164,11 @@ export async function POST(req: NextRequest) {
       temperature: 0,
     });
     const retriever = await getRetriever();
-    const answerChain = createChain(llm, retriever, !!convertedChatHistory.length);
+    const answerChain = createChain(
+      llm,
+      retriever,
+      !!convertedChatHistory.length,
+    );
 
     /**
      * Narrows streamed log output down to final output and the FindDocs tagged chain to
@@ -142,14 +178,18 @@
      * you can pass directly to the Response as well:
      * https://js.langchain.com/docs/expression_language/interface#stream
      */
-    const stream = await answerChain.streamLog({
-      question,
-      chat_history: convertedChatHistory,
-    }, {
-      metadata
-    }, {
-      includeTags: ["FindDocs"],
-    });
+    const stream = await answerChain.streamLog(
+      {
+        question,
+        chat_history: convertedChatHistory,
+      },
+      {
+        metadata,
+      },
+      {
+        includeTags: ["FindDocs"],
+      },
+    );
 
     // Only return a selection of output to the frontend
     const textEncoder = new TextEncoder();
@@ -162,24 +202,33 @@
           let hasEnqueued = false;
           for (const op of value.ops) {
             if ("value" in op) {
-              if (op.path === "/logs/0/final_output" && Array.isArray(op.value.output)) {
+              if (
+                op.path === "/logs/0/final_output" &&
+                Array.isArray(op.value.output)
+              ) {
                 const allSources = op.value.output.map((doc: Document) => {
                   return {
                     url: doc.metadata.source,
                     title: doc.metadata.title,
-                  }
+                  };
                 });
                 if (allSources.length) {
-                  const chunk = textEncoder.encode(JSON.stringify({ sources: allSources }) + "\n");
+                  const chunk = textEncoder.encode(
+                    JSON.stringify({ sources: allSources }) + "\n",
+                  );
                   controller.enqueue(chunk);
                   hasEnqueued = true;
                 }
               } else if (op.path === "/streamed_output/-") {
-                const chunk = textEncoder.encode(JSON.stringify({tok: op.value}) + "\n");
+                const chunk = textEncoder.encode(
+                  JSON.stringify({ tok: op.value }) + "\n",
+                );
                 controller.enqueue(chunk);
                 hasEnqueued = true;
               } else if (op.path === "" && op.op === "replace") {
-                const chunk = textEncoder.encode(JSON.stringify({run_id: op.value.id}) + "\n");
+                const chunk = textEncoder.encode(
+                  JSON.stringify({ run_id: op.value.id }) + "\n",
+                );
                 controller.enqueue(chunk);
                 hasEnqueued = true;
               }
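For orientation, the handler above streams newline-delimited JSON: a { sources: [...] } line once retrieval completes, { tok: "..." } lines as tokens are generated, and a { run_id: "..." } line when the run record replaces the root log entry. A minimal consumer sketch follows; the /api/chat path mirrors this file's location, and the message/history request fields are inferred from the validation code above rather than confirmed by this diff.

// Sketch: read the newline-delimited JSON emitted by app/api/chat/route.ts.
// The endpoint path and request field names are assumptions, not confirmed here.
async function askQuestion(question: string): Promise<string> {
  const response = await fetch("/api/chat", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ message: question, history: [] }),
  });
  const reader = response.body!.getReader();
  const decoder = new TextDecoder();
  let buffered = "";
  let answer = "";
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    buffered += decoder.decode(value, { stream: true });
    const lines = buffered.split("\n");
    buffered = lines.pop() ?? ""; // hold back any incomplete trailing line
    for (const line of lines) {
      if (!line) continue;
      const op = JSON.parse(line);
      if (op.sources) console.log("sources:", op.sources); // [{ url, title }]
      if (op.run_id) console.log("run id:", op.run_id);
      if (op.tok !== undefined) answer += op.tok; // token chunks of the answer
    }
  }
  return answer;
}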
20 changes: 16 additions & 4 deletions chat-langchain/app/api/feedback/route.ts
@@ -13,12 +13,18 @@ export async function POST(req: NextRequest) {
     const body = await req.json();
     const { run_id, key = "user_score", ...rest } = body;
     if (!run_id) {
-      return NextResponse.json({ error: "No LangSmith run ID provided" }, { status: 400 });
+      return NextResponse.json(
+        { error: "No LangSmith run ID provided" },
+        { status: 400 },
+      );
     }
 
     await client.createFeedback(run_id, key, rest);
 
-    return NextResponse.json({ result: "posted feedback successfully" }, { status: 200 });
+    return NextResponse.json(
+      { result: "posted feedback successfully" },
+      { status: 200 },
+    );
   } catch (e: any) {
     console.log(e);
     return NextResponse.json({ error: e.message }, { status: 500 });
@@ -30,12 +36,18 @@ export async function PATCH(req: NextRequest) {
     const body = await req.json();
     const { feedback_id, score, comment } = body;
     if (feedback_id === undefined) {
-      return NextResponse.json({ error: "No feedback ID provided" }, { status: 400 });
+      return NextResponse.json(
+        { error: "No feedback ID provided" },
+        { status: 400 },
+      );
     }
 
     await client.updateFeedback(feedback_id, { score, comment });
 
-    return NextResponse.json({ result: "patched feedback successfully" }, { status: 200 });
+    return NextResponse.json(
+      { result: "patched feedback successfully" },
+      { status: 200 },
+    );
   } catch (e: any) {
     console.log(e);
     return NextResponse.json({ error: e.message }, { status: 500 });
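Both handlers above read small JSON bodies whose keys match the destructuring in the code. A usage sketch, assuming the route is served at /api/feedback (per its file location) and a numeric score; the diff itself doesn't fix a scale.

// Sketch: call the feedback endpoints defined in app/api/feedback/route.ts.
// The path and score convention are assumptions based on the handler code.
async function postFeedback(runId: string, score: number): Promise<void> {
  // `key` defaults to "user_score" server-side; extra fields like `score`
  // travel through `...rest` into client.createFeedback.
  const res = await fetch("/api/feedback", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ run_id: runId, score }),
  });
  if (!res.ok) throw new Error(`POST /api/feedback failed: ${res.status}`);
}

async function amendFeedback(
  feedbackId: string,
  score: number,
  comment: string,
): Promise<void> {
  // PATCH requires a feedback_id; how the client obtains one is not shown here.
  const res = await fetch("/api/feedback", {
    method: "PATCH",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ feedback_id: feedbackId, score, comment }),
  });
  if (!res.ok) throw new Error(`PATCH /api/feedback failed: ${res.status}`);
}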
11 changes: 8 additions & 3 deletions chat-langchain/app/api/get_trace/route.ts
@@ -9,7 +9,9 @@ export const runtime = "edge";
 const client = new Client();
 
 const pollForRun = async (runId: string, retryCount = 0): Promise<string> => {
-  await new Promise((resolve) => setTimeout(resolve, retryCount * retryCount * 100));
+  await new Promise((resolve) =>
+    setTimeout(resolve, retryCount * retryCount * 100),
+  );
   try {
     await client.readRun(runId);
   } catch (e) {
@@ -24,14 +26,17 @@ const pollForRun = async (runId: string, retryCount = 0): Promise<string> => {
   } catch (e) {
     return client.shareRun(runId);
   }
-}
+};
 
 export async function POST(req: NextRequest) {
   try {
     const body = await req.json();
     const { run_id } = body;
     if (run_id === undefined) {
-      return NextResponse.json({ error: "No run ID provided" }, { status: 400 });
+      return NextResponse.json(
+        { error: "No run ID provided" },
+        { status: 400 },
+      );
     }
     const response = await pollForRun(run_id);
     return NextResponse.json(response, { status: 200 });
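One behavior worth noting in pollForRun above: the wait before each attempt grows quadratically with the retry count (retryCount * retryCount * 100 ms, i.e. 0 ms, 100 ms, 400 ms, 900 ms, ...), giving LangSmith time to register the run before it is shared. Below is a standalone sketch of that pattern; the retry cap is an assumption, since the actual cutoff lives in the collapsed part of the file.

// Sketch of pollForRun's backoff: attempt n waits n * n * 100 ms before running.
// maxRetries is assumed; the real limit is in the collapsed portion of the diff.
async function pollWithBackoff<T>(
  attempt: () => Promise<T>,
  maxRetries = 5,
): Promise<T> {
  for (let retryCount = 0; ; retryCount += 1) {
    await new Promise((resolve) =>
      setTimeout(resolve, retryCount * retryCount * 100),
    );
    try {
      return await attempt(); // e.g. () => client.readRun(runId)
    } catch (e) {
      if (retryCount >= maxRetries) throw e; // give up after the cap
    }
  }
}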