Skip to content

Commit

Permalink
refactor: introduce inference tools (janhq#2493)
Browse files Browse the repository at this point in the history
  • Loading branch information
louis-jan authored Mar 25, 2024
1 parent 14a6746 commit 8e8dfd4
Show file tree
Hide file tree
Showing 15 changed files with 240 additions and 193 deletions.
2 changes: 1 addition & 1 deletion core/src/browser/extensions/engines/AIEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ export abstract class AIEngine extends BaseExtension {
* Registers AI Engines
*/
registerEngine() {
EngineManager.instance()?.register(this)
EngineManager.instance().register(this)
}

/**
Expand Down
7 changes: 5 additions & 2 deletions core/src/browser/extensions/engines/EngineManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ export class EngineManager {
return this.engines.get(provider) as T | undefined
}

static instance(): EngineManager | undefined {
return window.core?.engineManager as EngineManager
/**
* The instance of the engine manager.
*/
static instance(): EngineManager {
return window.core?.engineManager as EngineManager ?? new EngineManager()
}
}
1 change: 1 addition & 0 deletions core/src/browser/extensions/engines/OAIEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ export abstract class OAIEngine extends AIEngine {
return
}
message.status = MessageStatus.Error
message.error_code = err.code
events.emit(MessageEvent.OnMessageUpdate, message)
},
})
Expand Down
12 changes: 11 additions & 1 deletion core/src/browser/extensions/engines/helpers/sse.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Observable } from 'rxjs'
import { ModelRuntimeParams } from '../../../../types'
import { ErrorCode, ModelRuntimeParams } from '../../../../types'
/**
* Sends a request to the inference server to generate a response based on the recent messages.
* @param recentMessages - An array of recent messages to use as context for the inference.
Expand Down Expand Up @@ -34,6 +34,16 @@ export function requestInference(
signal: controller?.signal,
})
.then(async (response) => {
if (!response.ok) {
const data = await response.json()
const error = {
message: data.error?.message ?? 'Error occurred.',
code: data.error?.code ?? ErrorCode.Unknown,
}
subscriber.error(error)
subscriber.complete()
return
}
if (model.parameters.stream === false) {
const data = await response.json()
subscriber.next(data.choices[0]?.message?.content ?? '')
Expand Down
6 changes: 6 additions & 0 deletions core/src/browser/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,9 @@ export * from './extension'
* @module
*/
export * from './extensions'

/**
* Export all base tools.
* @module
*/
export * from './tools'
2 changes: 2 additions & 0 deletions core/src/browser/tools/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
export * from './manager'
export * from './tool'
47 changes: 47 additions & 0 deletions core/src/browser/tools/manager.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import { AssistantTool, MessageRequest } from '../../types'
import { InferenceTool } from './tool'

/**
* Manages the registration and retrieval of inference tools.
*/
export class ToolManager {
public tools = new Map<string, InferenceTool>()

/**
* Registers a tool.
* @param tool - The tool to register.
*/
register<T extends InferenceTool>(tool: T) {
this.tools.set(tool.name, tool)
}

/**
* Retrieves a tool by it's name.
* @param name - The name of the tool to retrieve.
* @returns The tool, if found.
*/
get<T extends InferenceTool>(name: string): T | undefined {
return this.tools.get(name) as T | undefined
}

/*
** Process the message request with the tools.
*/
process(request: MessageRequest, tools: AssistantTool[]): Promise<MessageRequest> {
return tools.reduce((prevPromise, currentTool) => {
return prevPromise.then((prevResult) => {
return currentTool.enabled
? this.get(currentTool.type)?.process(prevResult, currentTool) ??
Promise.resolve(prevResult)
: Promise.resolve(prevResult)
})
}, Promise.resolve(request))
}

/**
* The instance of the tool manager.
*/
static instance(): ToolManager {
return (window.core?.toolManager as ToolManager) ?? new ToolManager()
}
}
12 changes: 12 additions & 0 deletions core/src/browser/tools/tool.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import { AssistantTool, MessageRequest } from '../../types'

/**
* Represents a base inference tool.
*/
export abstract class InferenceTool {
abstract name: string
/*
** Process a message request and return the processed message request.
*/
abstract process(request: MessageRequest, tool?: AssistantTool): Promise<MessageRequest>
}
5 changes: 0 additions & 5 deletions core/src/types/model/modelEntity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ export type ModelInfo = {
settings: ModelSettingParams
parameters: ModelRuntimeParams
engine?: InferenceEngine
proxy_model?: InferenceEngine
}

/**
Expand All @@ -21,8 +20,6 @@ export enum InferenceEngine {
groq = 'groq',
triton_trtllm = 'triton_trtllm',
nitro_tensorrt_llm = 'nitro-tensorrt-llm',

tool_retrieval_enabled = 'tool_retrieval_enabled',
}

export type ModelArtifact = {
Expand Down Expand Up @@ -94,8 +91,6 @@ export type Model = {
* The model engine.
*/
engine: InferenceEngine

proxy_model?: InferenceEngine
}

export type ModelMetadata = {
Expand Down
149 changes: 5 additions & 144 deletions extensions/assistant-extension/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,26 +1,21 @@
import {
fs,
Assistant,
MessageRequest,
events,
InferenceEngine,
MessageEvent,
InferenceEvent,
joinPath,
executeOnMain,
AssistantExtension,
AssistantEvent,
ToolManager,
} from '@janhq/core'
import { RetrievalTool } from './tools/retrieval'

export default class JanAssistantExtension extends AssistantExtension {
private static readonly _homeDir = 'file://assistants'
private static readonly _threadDir = 'file://threads'

controller = new AbortController()
isCancelled = false
retrievalThreadId: string | undefined = undefined

async onLoad() {
// Register the retrieval tool
ToolManager.instance().register(new RetrievalTool())

// making the assistant directory
const assistantDirExist = await fs.existsSync(
JanAssistantExtension._homeDir
Expand All @@ -38,140 +33,6 @@ export default class JanAssistantExtension extends AssistantExtension {
// Update the assistant list
events.emit(AssistantEvent.OnAssistantsUpdate, {})
}

// Events subscription
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) =>
JanAssistantExtension.handleMessageRequest(data, this)
)

events.on(InferenceEvent.OnInferenceStopped, () => {
JanAssistantExtension.handleInferenceStopped(this)
})
}

private static async handleInferenceStopped(instance: JanAssistantExtension) {
instance.isCancelled = true
instance.controller?.abort()
}

private static async handleMessageRequest(
data: MessageRequest,
instance: JanAssistantExtension
) {
instance.isCancelled = false
instance.controller = new AbortController()

if (
data.model?.engine !== InferenceEngine.tool_retrieval_enabled ||
!data.messages ||
// TODO: Since the engine is defined, its unsafe to assume that assistant tools are defined
// That could lead to an issue where thread stuck at generating response
!data.thread?.assistants[0]?.tools
) {
return
}

const latestMessage = data.messages[data.messages.length - 1]

// 1. Ingest the document if needed
if (
latestMessage &&
latestMessage.content &&
typeof latestMessage.content !== 'string' &&
latestMessage.content.length > 1
) {
const docFile = latestMessage.content[1]?.doc_url?.url
if (docFile) {
await executeOnMain(
NODE,
'toolRetrievalIngestNewDocument',
docFile,
data.model?.proxy_model
)
}
} else if (
// Check whether we need to ingest document or not
// Otherwise wrong context will be sent
!(await fs.existsSync(
await joinPath([
JanAssistantExtension._threadDir,
data.threadId,
'memory',
])
))
) {
// No document ingested, reroute the result to inference engine
const output = {
...data,
model: {
...data.model,
engine: data.model.proxy_model,
},
}
events.emit(MessageEvent.OnMessageSent, output)
return
}
// 2. Load agent on thread changed
if (instance.retrievalThreadId !== data.threadId) {
await executeOnMain(NODE, 'toolRetrievalLoadThreadMemory', data.threadId)

instance.retrievalThreadId = data.threadId

// Update the text splitter
await executeOnMain(
NODE,
'toolRetrievalUpdateTextSplitter',
data.thread.assistants[0].tools[0]?.settings?.chunk_size ?? 4000,
data.thread.assistants[0].tools[0]?.settings?.chunk_overlap ?? 200
)
}

// 3. Using the retrieval template with the result and query
if (latestMessage.content) {
const prompt =
typeof latestMessage.content === 'string'
? latestMessage.content
: latestMessage.content[0].text
// Retrieve the result
const retrievalResult = await executeOnMain(
NODE,
'toolRetrievalQueryResult',
prompt
)
console.debug('toolRetrievalQueryResult', retrievalResult)

// Update message content
if (data.thread?.assistants[0]?.tools && retrievalResult)
data.messages[data.messages.length - 1].content =
data.thread.assistants[0].tools[0].settings?.retrieval_template
?.replace('{CONTEXT}', retrievalResult)
.replace('{QUESTION}', prompt)
}

// Filter out all the messages that are not text
data.messages = data.messages.map((message) => {
if (
message.content &&
typeof message.content !== 'string' &&
(message.content.length ?? 0) > 0
) {
return {
...message,
content: [message.content[0]],
}
}
return message
})

// 4. Reroute the result to inference engine
const output = {
...data,
model: {
...data.model,
engine: data.model.proxy_model,
},
}
events.emit(MessageEvent.OnMessageSent, output)
}

/**
Expand Down
Loading

0 comments on commit 8e8dfd4

Please sign in to comment.