From 8e8dfd4b3752bd13af4539ac4ca94f35e0eff811 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 25 Mar 2024 23:26:05 +0700 Subject: [PATCH] refactor: introduce inference tools (#2493) --- .../browser/extensions/engines/AIEngine.ts | 2 +- .../extensions/engines/EngineManager.ts | 7 +- .../browser/extensions/engines/OAIEngine.ts | 1 + .../browser/extensions/engines/helpers/sse.ts | 12 +- core/src/browser/index.ts | 6 + core/src/browser/tools/index.ts | 2 + core/src/browser/tools/manager.ts | 47 ++++++ core/src/browser/tools/tool.ts | 12 ++ core/src/types/model/modelEntity.ts | 5 - extensions/assistant-extension/src/index.ts | 149 +----------------- .../src/tools/retrieval.ts | 108 +++++++++++++ web/containers/Providers/EventHandler.tsx | 2 +- web/hooks/useActiveModel.ts | 6 +- web/hooks/useSendChatMessage.ts | 71 +++++---- web/services/coreService.ts | 3 +- 15 files changed, 240 insertions(+), 193 deletions(-) create mode 100644 core/src/browser/tools/index.ts create mode 100644 core/src/browser/tools/manager.ts create mode 100644 core/src/browser/tools/tool.ts create mode 100644 extensions/assistant-extension/src/tools/retrieval.ts diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts index 3e63e67cba..c4f8168297 100644 --- a/core/src/browser/extensions/engines/AIEngine.ts +++ b/core/src/browser/extensions/engines/AIEngine.ts @@ -36,7 +36,7 @@ export abstract class AIEngine extends BaseExtension { * Registers AI Engines */ registerEngine() { - EngineManager.instance()?.register(this) + EngineManager.instance().register(this) } /** diff --git a/core/src/browser/extensions/engines/EngineManager.ts b/core/src/browser/extensions/engines/EngineManager.ts index 1de31ab8da..2980c5c65e 100644 --- a/core/src/browser/extensions/engines/EngineManager.ts +++ b/core/src/browser/extensions/engines/EngineManager.ts @@ -23,7 +23,10 @@ export class EngineManager { return this.engines.get(provider) as T | undefined } - static 
instance(): EngineManager | undefined { - return window.core?.engineManager as EngineManager + /** + * The instance of the engine manager. + */ + static instance(): EngineManager { + return window.core?.engineManager as EngineManager ?? new EngineManager() } } diff --git a/core/src/browser/extensions/engines/OAIEngine.ts b/core/src/browser/extensions/engines/OAIEngine.ts index 53baaae2a5..41b08f4598 100644 --- a/core/src/browser/extensions/engines/OAIEngine.ts +++ b/core/src/browser/extensions/engines/OAIEngine.ts @@ -106,6 +106,7 @@ export abstract class OAIEngine extends AIEngine { return } message.status = MessageStatus.Error + message.error_code = err.code events.emit(MessageEvent.OnMessageUpdate, message) }, }) diff --git a/core/src/browser/extensions/engines/helpers/sse.ts b/core/src/browser/extensions/engines/helpers/sse.ts index 28d24ee478..def017ebc6 100644 --- a/core/src/browser/extensions/engines/helpers/sse.ts +++ b/core/src/browser/extensions/engines/helpers/sse.ts @@ -1,5 +1,5 @@ import { Observable } from 'rxjs' -import { ModelRuntimeParams } from '../../../../types' +import { ErrorCode, ModelRuntimeParams } from '../../../../types' /** * Sends a request to the inference server to generate a response based on the recent messages. * @param recentMessages - An array of recent messages to use as context for the inference. @@ -34,6 +34,16 @@ export function requestInference( signal: controller?.signal, }) .then(async (response) => { + if (!response.ok) { + const data = await response.json() + const error = { + message: data.error?.message ?? 'Error occurred.', + code: data.error?.code ?? ErrorCode.Unknown, + } + subscriber.error(error) + subscriber.complete() + return + } if (model.parameters.stream === false) { const data = await response.json() subscriber.next(data.choices[0]?.message?.content ?? 
'') diff --git a/core/src/browser/index.ts b/core/src/browser/index.ts index 631baf06c1..a7803c7e04 100644 --- a/core/src/browser/index.ts +++ b/core/src/browser/index.ts @@ -27,3 +27,9 @@ export * from './extension' * @module */ export * from './extensions' + +/** + * Export all base tools. + * @module + */ +export * from './tools' diff --git a/core/src/browser/tools/index.ts b/core/src/browser/tools/index.ts new file mode 100644 index 0000000000..24cd127804 --- /dev/null +++ b/core/src/browser/tools/index.ts @@ -0,0 +1,2 @@ +export * from './manager' +export * from './tool' diff --git a/core/src/browser/tools/manager.ts b/core/src/browser/tools/manager.ts new file mode 100644 index 0000000000..b323ad7ced --- /dev/null +++ b/core/src/browser/tools/manager.ts @@ -0,0 +1,47 @@ +import { AssistantTool, MessageRequest } from '../../types' +import { InferenceTool } from './tool' + +/** + * Manages the registration and retrieval of inference tools. + */ +export class ToolManager { + public tools = new Map<string, InferenceTool>() + + /** + * Registers a tool. + * @param tool - The tool to register. + */ + register<T extends InferenceTool>(tool: T) { + this.tools.set(tool.name, tool) + } + + /** + * Retrieves a tool by its name. + * @param name - The name of the tool to retrieve. + * @returns The tool, if found. + */ + get<T extends InferenceTool>(name: string): T | undefined { + return this.tools.get(name) as T | undefined + } + + /* + ** Process the message request with the tools. + */ + process(request: MessageRequest, tools: AssistantTool[]): Promise<MessageRequest> { + return tools.reduce((prevPromise, currentTool) => { + return prevPromise.then((prevResult) => { + return currentTool.enabled + ? this.get(currentTool.type)?.process(prevResult, currentTool) ?? + Promise.resolve(prevResult) + : Promise.resolve(prevResult) + }) + }, Promise.resolve(request)) + } + + /** + * The instance of the tool manager. + */ + static instance(): ToolManager { + return (window.core?.toolManager as ToolManager) ?? 
new ToolManager() + } +} diff --git a/core/src/browser/tools/tool.ts b/core/src/browser/tools/tool.ts new file mode 100644 index 0000000000..0fd3429331 --- /dev/null +++ b/core/src/browser/tools/tool.ts @@ -0,0 +1,12 @@ +import { AssistantTool, MessageRequest } from '../../types' + +/** + * Represents a base inference tool. + */ +export abstract class InferenceTool { + abstract name: string + /* + ** Process a message request and return the processed message request. + */ + abstract process(request: MessageRequest, tool?: AssistantTool): Promise<MessageRequest> +} diff --git a/core/src/types/model/modelEntity.ts b/core/src/types/model/modelEntity.ts index d62a7c3871..a313847b69 100644 --- a/core/src/types/model/modelEntity.ts +++ b/core/src/types/model/modelEntity.ts @@ -7,7 +7,6 @@ export type ModelInfo = { settings: ModelSettingParams parameters: ModelRuntimeParams engine?: InferenceEngine - proxy_model?: InferenceEngine } /** @@ -21,8 +20,6 @@ export enum InferenceEngine { groq = 'groq', triton_trtllm = 'triton_trtllm', nitro_tensorrt_llm = 'nitro-tensorrt-llm', - - tool_retrieval_enabled = 'tool_retrieval_enabled', } export type ModelArtifact = { @@ -94,8 +91,6 @@ export type Model = { * The model engine. 
*/ engine: InferenceEngine - - proxy_model?: InferenceEngine } export type ModelMetadata = { diff --git a/extensions/assistant-extension/src/index.ts b/extensions/assistant-extension/src/index.ts index 97a1cb2207..64528b0e09 100644 --- a/extensions/assistant-extension/src/index.ts +++ b/extensions/assistant-extension/src/index.ts @@ -1,26 +1,21 @@ import { fs, Assistant, - MessageRequest, events, - InferenceEngine, - MessageEvent, - InferenceEvent, joinPath, - executeOnMain, AssistantExtension, AssistantEvent, + ToolManager, } from '@janhq/core' +import { RetrievalTool } from './tools/retrieval' export default class JanAssistantExtension extends AssistantExtension { private static readonly _homeDir = 'file://assistants' - private static readonly _threadDir = 'file://threads' - - controller = new AbortController() - isCancelled = false - retrievalThreadId: string | undefined = undefined async onLoad() { + // Register the retrieval tool + ToolManager.instance().register(new RetrievalTool()) + // making the assistant directory const assistantDirExist = await fs.existsSync( JanAssistantExtension._homeDir @@ -38,140 +33,6 @@ export default class JanAssistantExtension extends AssistantExtension { // Update the assistant list events.emit(AssistantEvent.OnAssistantsUpdate, {}) } - - // Events subscription - events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => - JanAssistantExtension.handleMessageRequest(data, this) - ) - - events.on(InferenceEvent.OnInferenceStopped, () => { - JanAssistantExtension.handleInferenceStopped(this) - }) - } - - private static async handleInferenceStopped(instance: JanAssistantExtension) { - instance.isCancelled = true - instance.controller?.abort() - } - - private static async handleMessageRequest( - data: MessageRequest, - instance: JanAssistantExtension - ) { - instance.isCancelled = false - instance.controller = new AbortController() - - if ( - data.model?.engine !== InferenceEngine.tool_retrieval_enabled || - !data.messages || - 
// TODO: Since the engine is defined, its unsafe to assume that assistant tools are defined - // That could lead to an issue where thread stuck at generating response - !data.thread?.assistants[0]?.tools - ) { - return - } - - const latestMessage = data.messages[data.messages.length - 1] - - // 1. Ingest the document if needed - if ( - latestMessage && - latestMessage.content && - typeof latestMessage.content !== 'string' && - latestMessage.content.length > 1 - ) { - const docFile = latestMessage.content[1]?.doc_url?.url - if (docFile) { - await executeOnMain( - NODE, - 'toolRetrievalIngestNewDocument', - docFile, - data.model?.proxy_model - ) - } - } else if ( - // Check whether we need to ingest document or not - // Otherwise wrong context will be sent - !(await fs.existsSync( - await joinPath([ - JanAssistantExtension._threadDir, - data.threadId, - 'memory', - ]) - )) - ) { - // No document ingested, reroute the result to inference engine - const output = { - ...data, - model: { - ...data.model, - engine: data.model.proxy_model, - }, - } - events.emit(MessageEvent.OnMessageSent, output) - return - } - // 2. Load agent on thread changed - if (instance.retrievalThreadId !== data.threadId) { - await executeOnMain(NODE, 'toolRetrievalLoadThreadMemory', data.threadId) - - instance.retrievalThreadId = data.threadId - - // Update the text splitter - await executeOnMain( - NODE, - 'toolRetrievalUpdateTextSplitter', - data.thread.assistants[0].tools[0]?.settings?.chunk_size ?? 4000, - data.thread.assistants[0].tools[0]?.settings?.chunk_overlap ?? 200 - ) - } - - // 3. Using the retrieval template with the result and query - if (latestMessage.content) { - const prompt = - typeof latestMessage.content === 'string' - ? 
latestMessage.content - : latestMessage.content[0].text - // Retrieve the result - const retrievalResult = await executeOnMain( - NODE, - 'toolRetrievalQueryResult', - prompt - ) - console.debug('toolRetrievalQueryResult', retrievalResult) - - // Update message content - if (data.thread?.assistants[0]?.tools && retrievalResult) - data.messages[data.messages.length - 1].content = - data.thread.assistants[0].tools[0].settings?.retrieval_template - ?.replace('{CONTEXT}', retrievalResult) - .replace('{QUESTION}', prompt) - } - - // Filter out all the messages that are not text - data.messages = data.messages.map((message) => { - if ( - message.content && - typeof message.content !== 'string' && - (message.content.length ?? 0) > 0 - ) { - return { - ...message, - content: [message.content[0]], - } - } - return message - }) - - // 4. Reroute the result to inference engine - const output = { - ...data, - model: { - ...data.model, - engine: data.model.proxy_model, - }, - } - events.emit(MessageEvent.OnMessageSent, output) } /** diff --git a/extensions/assistant-extension/src/tools/retrieval.ts b/extensions/assistant-extension/src/tools/retrieval.ts new file mode 100644 index 0000000000..35738fd8e0 --- /dev/null +++ b/extensions/assistant-extension/src/tools/retrieval.ts @@ -0,0 +1,108 @@ +import { + AssistantTool, + executeOnMain, + fs, + InferenceTool, + joinPath, + MessageRequest, +} from '@janhq/core' + +export class RetrievalTool extends InferenceTool { + private _threadDir = 'file://threads' + private retrievalThreadId: string | undefined = undefined + + name: string = 'retrieval' + + async process( + data: MessageRequest, + tool?: AssistantTool + ): Promise<MessageRequest> { + if (!data.model || !data.messages) { + return Promise.resolve(data) + } + + const latestMessage = data.messages[data.messages.length - 1] + + // 1. 
Ingest the document if needed + if ( + latestMessage && + latestMessage.content && + typeof latestMessage.content !== 'string' && + latestMessage.content.length > 1 + ) { + const docFile = latestMessage.content[1]?.doc_url?.url + if (docFile) { + await executeOnMain( + NODE, + 'toolRetrievalIngestNewDocument', + docFile, + data.model?.engine + ) + } + } else if ( + // Check whether we need to ingest document or not + // Otherwise wrong context will be sent + !(await fs.existsSync( + await joinPath([this._threadDir, data.threadId, 'memory']) + )) + ) { + // No document ingested, reroute the result to inference engine + + return Promise.resolve(data) + } + // 2. Load agent on thread changed + if (this.retrievalThreadId !== data.threadId) { + await executeOnMain(NODE, 'toolRetrievalLoadThreadMemory', data.threadId) + + this.retrievalThreadId = data.threadId + + // Update the text splitter + await executeOnMain( + NODE, + 'toolRetrievalUpdateTextSplitter', + tool?.settings?.chunk_size ?? 4000, + tool?.settings?.chunk_overlap ?? 200 + ) + } + + // 3. Using the retrieval template with the result and query + if (latestMessage.content) { + const prompt = + typeof latestMessage.content === 'string' + ? latestMessage.content + : latestMessage.content[0].text + // Retrieve the result + const retrievalResult = await executeOnMain( + NODE, + 'toolRetrievalQueryResult', + prompt + ) + console.debug('toolRetrievalQueryResult', retrievalResult) + + // Update message content + if (retrievalResult) + data.messages[data.messages.length - 1].content = + tool?.settings?.retrieval_template + ?.replace('{CONTEXT}', retrievalResult) + .replace('{QUESTION}', prompt) + } + + // Filter out all the messages that are not text + data.messages = data.messages.map((message) => { + if ( + message.content && + typeof message.content !== 'string' && + (message.content.length ?? 0) > 0 + ) { + return { + ...message, + content: [message.content[0]], + } + } + return message + }) + + // 4. 
Reroute the result to inference engine + return Promise.resolve(data) + } +} diff --git a/web/containers/Providers/EventHandler.tsx b/web/containers/Providers/EventHandler.tsx index d62cb3f8fc..4d5555a469 100644 --- a/web/containers/Providers/EventHandler.tsx +++ b/web/containers/Providers/EventHandler.tsx @@ -230,7 +230,7 @@ export default function EventHandler({ children }: { children: ReactNode }) { // 2. Update the title with the result of the inference setTimeout(() => { - const engine = EngineManager.instance()?.get( + const engine = EngineManager.instance().get( messageRequest.model?.engine ?? activeModelRef.current?.engine ?? '' ) engine?.inference(messageRequest) diff --git a/web/hooks/useActiveModel.ts b/web/hooks/useActiveModel.ts index e7cd4888d2..0da28efe4f 100644 --- a/web/hooks/useActiveModel.ts +++ b/web/hooks/useActiveModel.ts @@ -78,7 +78,7 @@ export function useActiveModel() { } localStorage.setItem(LAST_USED_MODEL_ID, model.id) - const engine = EngineManager.instance()?.get(model.engine) + const engine = EngineManager.instance().get(model.engine) return engine ?.loadModel(model) .then(() => { @@ -95,7 +95,6 @@ export function useActiveModel() { }) }) .catch((error) => { - console.error('Failed to load model: ', error) setStateModel(() => ({ state: 'start', loading: false, @@ -108,13 +107,14 @@ export function useActiveModel() { type: 'success', }) setLoadModelError(error) + return Promise.reject(error) }) } const stopModel = useCallback(async () => { if (activeModel) { setStateModel({ state: 'stop', loading: true, model: activeModel.id }) - const engine = EngineManager.instance()?.get(activeModel.engine) + const engine = EngineManager.instance().get(activeModel.engine) await engine ?.unloadModel(activeModel) .catch() diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts index d316cdbeec..b380320091 100644 --- a/web/hooks/useSendChatMessage.ts +++ b/web/hooks/useSendChatMessage.ts @@ -9,9 +9,8 @@ import { ThreadMessage, 
Model, ConversationalExtension, - InferenceEngine, - AssistantTool, EngineManager, + ToolManager, } from '@janhq/core' import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' @@ -111,7 +110,10 @@ export default function useSendChatMessage() { activeThreadRef.current.assistants[0].model.id if (modelRef.current?.id !== modelId) { - await startModel(modelId) + const error = await startModel(modelId).catch((error: Error) => error) + if (error) { + return + } } setIsGeneratingResponse(true) @@ -128,10 +130,18 @@ export default function useSendChatMessage() { ) } } - const engine = EngineManager.instance()?.get( - requestBuilder.model?.engine ?? selectedModelRef.current?.engine ?? '' + // Process message request with Assistants tools + const request = await ToolManager.instance().process( + requestBuilder.build(), + activeThreadRef.current.assistants?.flatMap( + (assistant) => assistant.tools ?? [] + ) ?? [] ) - engine?.inference(requestBuilder.build()) + + const engine = + requestBuilder.model?.engine ?? selectedModelRef.current?.engine ?? '' + + EngineManager.instance().get(engine)?.inference(request) } // Define interface extending Array prototype @@ -149,8 +159,9 @@ export default function useSendChatMessage() { const runtimeParams = toRuntimeParams(activeModelParams) const settingParams = toSettingParams(activeModelParams) - updateThreadWaiting(activeThreadRef.current.id, true) const prompt = message.trim() + + updateThreadWaiting(activeThreadRef.current.id, true) setCurrentPrompt('') setEditPrompt('') @@ -158,17 +169,12 @@ export default function useSendChatMessage() { ? 
await getBase64(fileUpload[0].file) : undefined - const fileContentType = fileUpload[0]?.type - - const isDocumentInput = base64Blob && fileContentType === 'pdf' - const isImageInput = base64Blob && fileContentType === 'image' - - if (isImageInput && base64Blob) { + if (base64Blob && fileUpload[0]?.type === 'image') { // Compress image base64Blob = await compressImage(base64Blob, 512) } - let modelRequest = + const modelRequest = selectedModelRef?.current ?? activeThreadRef.current.assistants[0].model // Fallback support for previous broken threads @@ -182,23 +188,6 @@ export default function useSendChatMessage() { if (runtimeParams.stream == null) { runtimeParams.stream = true } - // Add middleware to the model request with tool retrieval enabled - if ( - activeThreadRef.current.assistants[0].tools?.some( - (tool: AssistantTool) => tool.type === 'retrieval' && tool.enabled - ) - ) { - modelRequest = { - ...modelRequest, - // Tool retrieval support document input only for now - ...(isDocumentInput - ? { - engine: InferenceEngine.tool_retrieval_enabled, - proxy_model: modelRequest.engine, - } - : {}), - } - } // Build Message Request const requestBuilder = new MessageRequestBuilder( @@ -247,15 +236,27 @@ export default function useSendChatMessage() { if (modelRef.current?.id !== modelId) { setQueuedMessage(true) - await startModel(modelId) + const error = await startModel(modelId).catch((error: Error) => error) setQueuedMessage(false) + if (error) { + updateThreadWaiting(activeThreadRef.current.id, false) + return + } } setIsGeneratingResponse(true) - const engine = EngineManager.instance()?.get( - requestBuilder.model?.engine ?? modelRequest.engine ?? '' + // Process message request with Assistants tools + const request = await ToolManager.instance().process( + requestBuilder.build(), + activeThreadRef.current.assistants?.flatMap( + (assistant) => assistant.tools ?? [] + ) ?? 
[] ) - engine?.inference(requestBuilder.build()) + + // Request for inference + EngineManager.instance() + .get(requestBuilder.model?.engine ?? modelRequest.engine ?? '') + ?.inference(request) // Reset states setReloadModel(false) diff --git a/web/services/coreService.ts b/web/services/coreService.ts index aa76a9c1a0..aeb1cca1a9 100644 --- a/web/services/coreService.ts +++ b/web/services/coreService.ts @@ -1,4 +1,4 @@ -import { EngineManager } from '@janhq/core' +import { EngineManager, ToolManager } from '@janhq/core' import { appService } from './appService' import { EventEmitter } from './eventsService' @@ -15,6 +15,7 @@ export const setupCoreServices = () => { window.core = { events: new EventEmitter(), engineManager: new EngineManager(), + toolManager: new ToolManager(), api: { ...(window.electronAPI ? window.electronAPI : restAPI), ...appService,