From 08c60a70c2e53d6c37af6936f4814a5898d6c7ec Mon Sep 17 00:00:00 2001 From: Nathan Date: Thu, 20 Jun 2024 16:34:50 +0700 Subject: [PATCH] add time weighted retrieval (#2908) * add time weighted retrieval * add missing configuration for timeWeightedVectorStore * resolving conflict * add missing configuration for timeWeightedVectorStore * resolving conflict * fix linting issues * fix build failed due to requirement for useTimeWeightedRetriever in AssistantTool * update web packages complying the new structure --------- Co-authored-by: thu --- core/src/types/assistant/assistantEntity.ts | 1 + extensions/assistant-extension/src/index.ts | 1 + .../assistant-extension/src/node/index.ts | 12 ++- .../assistant-extension/src/node/retrieval.ts | 52 +++++++++++- .../src/tools/retrieval.ts | 6 +- .../Thread/ThreadRightPanel/Tools/index.tsx | 80 ++++++++++++++++--- 6 files changed, 135 insertions(+), 17 deletions(-) diff --git a/core/src/types/assistant/assistantEntity.ts b/core/src/types/assistant/assistantEntity.ts index 733dbea8d8..27592e26b6 100644 --- a/core/src/types/assistant/assistantEntity.ts +++ b/core/src/types/assistant/assistantEntity.ts @@ -6,6 +6,7 @@ export type AssistantTool = { type: string enabled: boolean + useTimeWeightedRetriever?: boolean settings: any } diff --git a/extensions/assistant-extension/src/index.ts b/extensions/assistant-extension/src/index.ts index 53d3ed0d5d..12441995ee 100644 --- a/extensions/assistant-extension/src/index.ts +++ b/extensions/assistant-extension/src/index.ts @@ -126,6 +126,7 @@ export default class JanAssistantExtension extends AssistantExtension { { type: 'retrieval', enabled: false, + useTimeWeightedRetriever: false, settings: { top_k: 2, chunk_size: 1024, diff --git a/extensions/assistant-extension/src/node/index.ts b/extensions/assistant-extension/src/node/index.ts index 46835614d4..83a4a19831 100644 --- a/extensions/assistant-extension/src/node/index.ts +++ b/extensions/assistant-extension/src/node/index.ts @@ -11,13 +11,14 @@ export function toolRetrievalUpdateTextSplitter( export async function toolRetrievalIngestNewDocument( file: string, model: string, - engine: string + engine: string, + useTimeWeighted: boolean ) { const filePath = path.join(getJanDataFolderPath(), normalizeFilePath(file)) const threadPath = path.dirname(filePath.replace('files', '')) retrieval.updateEmbeddingEngine(model, engine) return retrieval - .ingestAgentKnowledge(filePath, `${threadPath}/memory`) + .ingestAgentKnowledge(filePath, `${threadPath}/memory`, useTimeWeighted) .catch((err) => { console.error(err) }) @@ -33,8 +34,11 @@ export async function toolRetrievalLoadThreadMemory(threadId: string) { }) } -export async function toolRetrievalQueryResult(query: string) { - return retrieval.generateResult(query).catch((err) => { +export async function toolRetrievalQueryResult( + query: string, + useTimeWeighted: boolean = false +) { + return retrieval.generateResult(query, useTimeWeighted).catch((err) => { console.error(err) }) } diff --git a/extensions/assistant-extension/src/node/retrieval.ts b/extensions/assistant-extension/src/node/retrieval.ts index 52193f221c..28d629aa80 100644 --- a/extensions/assistant-extension/src/node/retrieval.ts +++ b/extensions/assistant-extension/src/node/retrieval.ts @@ -2,11 +2,16 @@ import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter' import { formatDocumentsAsString } from 'langchain/util/document' import { PDFLoader } from 'langchain/document_loaders/fs/pdf' +import { TimeWeightedVectorStoreRetriever } from 'langchain/retrievers/time_weighted' +import { MemoryVectorStore } from 'langchain/vectorstores/memory' + import { HNSWLib } from 'langchain/vectorstores/hnswlib' import { OpenAIEmbeddings } from 'langchain/embeddings/openai' import { readEmbeddingEngine } from './engine' +import path from 'path' + export class Retrieval { public chunkSize: number = 100 public chunkOverlap?: number = 0 @@ -15,8 +20,25 @@ export class Retrieval { private embeddingModel?: OpenAIEmbeddings = undefined private textSplitter?: RecursiveCharacterTextSplitter + // to support time-weighted retrieval + private timeWeightedVectorStore: MemoryVectorStore + private timeWeightedretriever: any | TimeWeightedVectorStoreRetriever + constructor(chunkSize: number = 4000, chunkOverlap: number = 200) { this.updateTextSplitter(chunkSize, chunkOverlap) + + // declare time-weighted retriever and storage + this.timeWeightedVectorStore = new MemoryVectorStore( + new OpenAIEmbeddings( + { openAIApiKey: 'nitro-embedding' }, + { basePath: 'http://127.0.0.1:3928/v1' } + ) + ) + this.timeWeightedretriever = new TimeWeightedVectorStoreRetriever({ + vectorStore: this.timeWeightedVectorStore, + memoryStream: [], + searchKwargs: 2, + }) } public updateTextSplitter(chunkSize: number, chunkOverlap: number): void { @@ -44,11 +66,15 @@ export class Retrieval { openAIApiKey: settings.api_key, }) } + + // update time-weighted embedding model + this.timeWeightedVectorStore.embeddings = this.embeddingModel } public ingestAgentKnowledge = async ( filePath: string, - memoryPath: string + memoryPath: string, + useTimeWeighted: boolean ): Promise => { const loader = new PDFLoader(filePath, { splitPages: true, @@ -57,6 +83,13 @@ export class Retrieval { const doc = await loader.load() const docs = await this.textSplitter!.splitDocuments(doc) const vectorStore = await HNSWLib.fromDocuments(docs, this.embeddingModel) + + // add documents with metadata by using the time-weighted retriever in order to support time-weighted retrieval + if (useTimeWeighted && this.timeWeightedretriever) { + await ( + this.timeWeightedretriever as TimeWeightedVectorStoreRetriever + ).addDocuments(docs) + } return vectorStore.save(memoryPath) } @@ -67,10 +100,25 @@ export class Retrieval { return Promise.resolve() } - public generateResult = async (query: string): Promise => { + public generateResult = async ( + query: string, + useTimeWeighted: boolean + ): Promise => { + if (useTimeWeighted) { + if (!this.timeWeightedretriever) { + return Promise.resolve(' ') + } + // use invoke because getRelevantDocuments is deprecated + const relevantDocs = await this.timeWeightedretriever.invoke(query) + const serializedDoc = formatDocumentsAsString(relevantDocs) + return Promise.resolve(serializedDoc) + } + if (!this.retriever) { return Promise.resolve(' ') } + + // should use invoke(query) because getRelevantDocuments is deprecated const relevantDocs = await this.retriever.getRelevantDocuments(query) const serializedDoc = formatDocumentsAsString(relevantDocs) return Promise.resolve(serializedDoc) diff --git a/extensions/assistant-extension/src/tools/retrieval.ts b/extensions/assistant-extension/src/tools/retrieval.ts index a1a641941f..7631922871 100644 --- a/extensions/assistant-extension/src/tools/retrieval.ts +++ b/extensions/assistant-extension/src/tools/retrieval.ts @@ -37,7 +37,8 @@ export class RetrievalTool extends InferenceTool { 'toolRetrievalIngestNewDocument', docFile, data.model?.id, - data.model?.engine + data.model?.engine, + tool?.useTimeWeightedRetriever ?? false ) } else { return Promise.resolve(data) @@ -78,7 +79,8 @@ export class RetrievalTool extends InferenceTool { const retrievalResult = await executeOnMain( NODE, 'toolRetrievalQueryResult', - prompt + prompt, + tool?.useTimeWeightedRetriever ?? false ) console.debug('toolRetrievalQueryResult', retrievalResult) diff --git a/web/screens/Thread/ThreadRightPanel/Tools/index.tsx b/web/screens/Thread/ThreadRightPanel/Tools/index.tsx index 428cfbf9c8..7faecc08ae 100644 --- a/web/screens/Thread/ThreadRightPanel/Tools/index.tsx +++ b/web/screens/Thread/ThreadRightPanel/Tools/index.tsx @@ -66,6 +66,32 @@ const Tools = () => { [activeThread, updateThreadMetadata] ) + const onTimeWeightedRetrieverSwitchUpdate = useCallback( + (enabled: boolean) => { + if (!activeThread) return + updateThreadMetadata({ + ...activeThread, + assistants: [ + { + ...activeThread.assistants[0], + tools: [ + { + type: 'retrieval', + enabled: true, + useTimeWeightedRetriever: enabled, + settings: + (activeThread.assistants[0].tools && + activeThread.assistants[0].tools[0]?.settings) ?? + {}, + }, + ], + }, + ], + }) + }, + [activeThread, updateThreadMetadata] + ) + if (!experimentalFeature) return null return ( @@ -143,6 +169,46 @@ const Tools = () => { className="inline-block font-medium" > Vector Database + + } + content="Vector Database is crucial for efficient storage + and retrieval of embeddings. Consider your + specific task, available resources, and language + requirements. Experiment to find the best fit for + your specific use case." + /> + +
+ + onTimeWeightedRetrieverSwitchUpdate(e.target.checked) + } + /> +
+ + +
+ +
+ +
+
+ { className="ml-2 flex-shrink-0 text-[hsl(var(--text-secondary))]" /> } - content="Vector Database is crucial for efficient storage - and retrieval of embeddings. Consider your - specific task, available resources, and language - requirements. Experiment to find the best fit for - your specific use case." + content="Time-Weighted Retriever looks at how similar + they are and how new they are. It compares + documents based on their meaning like usual, but + also considers when they were added to give + newer ones more importance." />
- -
- -