Skip to content

Commit

Permalink
add time weighted retrieval (janhq#2908)
Browse files Browse the repository at this point in the history
* add time weighted retrieval

* add missing configuration for timeWeightedVectorStore

* resolving conflict

* add missing configuration for timeWeightedVectorStore

* resolving conflict

* fix linting issues

* fix build failed due to requirement for useTimeWeightedRetriever in AssistantTool

* update web packages complying the new structure

---------

Co-authored-by: thu <[email protected]>
  • Loading branch information
2 people authored and louis-jan committed Jul 12, 2024
1 parent 8077eb5 commit 08c60a7
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 17 deletions.
1 change: 1 addition & 0 deletions core/src/types/assistant/assistantEntity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
export type AssistantTool = {
type: string
enabled: boolean
useTimeWeightedRetriever?: boolean
settings: any
}

Expand Down
1 change: 1 addition & 0 deletions extensions/assistant-extension/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ export default class JanAssistantExtension extends AssistantExtension {
{
type: 'retrieval',
enabled: false,
useTimeWeightedRetriever: false,
settings: {
top_k: 2,
chunk_size: 1024,
Expand Down
12 changes: 8 additions & 4 deletions extensions/assistant-extension/src/node/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ export function toolRetrievalUpdateTextSplitter(
export async function toolRetrievalIngestNewDocument(
file: string,
model: string,
engine: string
engine: string,
useTimeWeighted: boolean
) {
const filePath = path.join(getJanDataFolderPath(), normalizeFilePath(file))
const threadPath = path.dirname(filePath.replace('files', ''))
retrieval.updateEmbeddingEngine(model, engine)
return retrieval
.ingestAgentKnowledge(filePath, `${threadPath}/memory`)
.ingestAgentKnowledge(filePath, `${threadPath}/memory`, useTimeWeighted)
.catch((err) => {
console.error(err)
})
Expand All @@ -33,8 +34,11 @@ export async function toolRetrievalLoadThreadMemory(threadId: string) {
})
}

export async function toolRetrievalQueryResult(query: string) {
return retrieval.generateResult(query).catch((err) => {
export async function toolRetrievalQueryResult(
query: string,
useTimeWeighted: boolean = false
) {
return retrieval.generateResult(query, useTimeWeighted).catch((err) => {
console.error(err)
})
}
52 changes: 50 additions & 2 deletions extensions/assistant-extension/src/node/retrieval.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@ import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter'
import { formatDocumentsAsString } from 'langchain/util/document'
import { PDFLoader } from 'langchain/document_loaders/fs/pdf'

import { TimeWeightedVectorStoreRetriever } from 'langchain/retrievers/time_weighted'
import { MemoryVectorStore } from 'langchain/vectorstores/memory'

import { HNSWLib } from 'langchain/vectorstores/hnswlib'

import { OpenAIEmbeddings } from 'langchain/embeddings/openai'
import { readEmbeddingEngine } from './engine'

import path from 'path'

export class Retrieval {
public chunkSize: number = 100
public chunkOverlap?: number = 0
Expand All @@ -15,8 +20,25 @@ export class Retrieval {
private embeddingModel?: OpenAIEmbeddings = undefined
private textSplitter?: RecursiveCharacterTextSplitter

// to support time-weighted retrieval
private timeWeightedVectorStore: MemoryVectorStore
private timeWeightedretriever: any | TimeWeightedVectorStoreRetriever

constructor(chunkSize: number = 4000, chunkOverlap: number = 200) {
this.updateTextSplitter(chunkSize, chunkOverlap)

// declare time-weighted retriever and storage
this.timeWeightedVectorStore = new MemoryVectorStore(
new OpenAIEmbeddings(
{ openAIApiKey: 'nitro-embedding' },
{ basePath: 'http://127.0.0.1:3928/v1' }
)
)
this.timeWeightedretriever = new TimeWeightedVectorStoreRetriever({
vectorStore: this.timeWeightedVectorStore,
memoryStream: [],
searchKwargs: 2,
})
}

public updateTextSplitter(chunkSize: number, chunkOverlap: number): void {
Expand Down Expand Up @@ -44,11 +66,15 @@ export class Retrieval {
openAIApiKey: settings.api_key,
})
}

// update time-weighted embedding model
this.timeWeightedVectorStore.embeddings = this.embeddingModel
}

public ingestAgentKnowledge = async (
filePath: string,
memoryPath: string
memoryPath: string,
useTimeWeighted: boolean
): Promise<any> => {
const loader = new PDFLoader(filePath, {
splitPages: true,
Expand All @@ -57,6 +83,13 @@ export class Retrieval {
const doc = await loader.load()
const docs = await this.textSplitter!.splitDocuments(doc)
const vectorStore = await HNSWLib.fromDocuments(docs, this.embeddingModel)

// add documents with metadata by using the time-weighted retriever in order to support time-weighted retrieval
if (useTimeWeighted && this.timeWeightedretriever) {
await (
this.timeWeightedretriever as TimeWeightedVectorStoreRetriever
).addDocuments(docs)
}
return vectorStore.save(memoryPath)
}

Expand All @@ -67,10 +100,25 @@ export class Retrieval {
return Promise.resolve()
}

public generateResult = async (query: string): Promise<string> => {
public generateResult = async (
query: string,
useTimeWeighted: boolean
): Promise<string> => {
if (useTimeWeighted) {
if (!this.timeWeightedretriever) {
return Promise.resolve(' ')
}
// use invoke because getRelevantDocuments is deprecated
const relevantDocs = await this.timeWeightedretriever.invoke(query)
const serializedDoc = formatDocumentsAsString(relevantDocs)
return Promise.resolve(serializedDoc)
}

if (!this.retriever) {
return Promise.resolve(' ')
}

// should use invoke(query) because getRelevantDocuments is deprecated
const relevantDocs = await this.retriever.getRelevantDocuments(query)
const serializedDoc = formatDocumentsAsString(relevantDocs)
return Promise.resolve(serializedDoc)
Expand Down
6 changes: 4 additions & 2 deletions extensions/assistant-extension/src/tools/retrieval.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ export class RetrievalTool extends InferenceTool {
'toolRetrievalIngestNewDocument',
docFile,
data.model?.id,
data.model?.engine
data.model?.engine,
tool?.useTimeWeightedRetriever ?? false
)
} else {
return Promise.resolve(data)
Expand Down Expand Up @@ -78,7 +79,8 @@ export class RetrievalTool extends InferenceTool {
const retrievalResult = await executeOnMain(
NODE,
'toolRetrievalQueryResult',
prompt
prompt,
tool?.useTimeWeightedRetriever ?? false
)
console.debug('toolRetrievalQueryResult', retrievalResult)

Expand Down
80 changes: 71 additions & 9 deletions web/screens/Thread/ThreadRightPanel/Tools/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,32 @@ const Tools = () => {
[activeThread, updateThreadMetadata]
)

const onTimeWeightedRetrieverSwitchUpdate = useCallback(
(enabled: boolean) => {
if (!activeThread) return
updateThreadMetadata({
...activeThread,
assistants: [
{
...activeThread.assistants[0],
tools: [
{
type: 'retrieval',
enabled: true,
useTimeWeightedRetriever: enabled,
settings:
(activeThread.assistants[0].tools &&
activeThread.assistants[0].tools[0]?.settings) ??
{},
},
],
},
],
})
},
[activeThread, updateThreadMetadata]
)

if (!experimentalFeature) return null

return (
Expand Down Expand Up @@ -143,6 +169,46 @@ const Tools = () => {
className="inline-block font-medium"
>
Vector Database
<Tooltip
trigger={
<InfoIcon
size={16}
className="ml-2 flex-shrink-0 text-[hsl(var(--text-secondary))]"
/>
}
content="Vector Database is crucial for efficient storage
and retrieval of embeddings. Consider your
specific task, available resources, and language
requirements. Experiment to find the best fit for
your specific use case."
/>
</label>
<div className="ml-auto flex items-center justify-between">
<Switch
name="use-time-weighted-retriever"
className="mr-2"
checked={
activeThread?.assistants[0].tools[0]
.useTimeWeightedRetriever || false
}
onChange={(e) =>
onTimeWeightedRetrieverSwitchUpdate(e.target.checked)
}
/>
</div>
</div>

<div className="w-full">
<Input value="HNSWLib" disabled readOnly />
</div>
</div>
<div className="mb-4">
<div className="mb-2 flex items-center">
<label
id="use-time-weighted-retriever"
className="inline-block font-medium"
>
Time-Weighted Retrieval?
</label>
<Tooltip
trigger={
Expand All @@ -151,17 +217,13 @@ const Tools = () => {
className="ml-2 flex-shrink-0 text-[hsl(var(--text-secondary))]"
/>
}
content="Vector Database is crucial for efficient storage
and retrieval of embeddings. Consider your
specific task, available resources, and language
requirements. Experiment to find the best fit for
your specific use case."
content="Time-Weighted Retriever looks at how similar
they are and how new they are. It compares
documents based on their meaning like usual, but
also considers when they were added to give
newer ones more importance."
/>
</div>

<div className="w-full">
<Input value="HNSWLib" disabled readOnly />
</div>
</div>
<AssistantSetting
componentData={componentDataAssistantSetting}
Expand Down

0 comments on commit 08c60a7

Please sign in to comment.