Skip to content

Commit

Permalink
refactor: introduce message request builder (janhq#2481)
Browse files Browse the repository at this point in the history
  • Loading branch information
louis-jan committed Mar 25, 2024
1 parent 9551996 commit 77cbdc2
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 154 deletions.
197 changes: 43 additions & 154 deletions web/hooks/useSendChatMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,19 @@
import { useEffect, useRef } from 'react'

import {
ChatCompletionMessage,
ChatCompletionRole,
ContentType,
MessageRequest,
MessageRequestType,
MessageStatus,
ExtensionTypeEnum,
Thread,
ThreadMessage,
Model,
ConversationalExtension,
InferenceEngine,
ChatCompletionMessageContentType,
AssistantTool,
EngineManager,
} from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'

import { ulid } from 'ulidx'

import { selectedModelAtom } from '@/containers/DropdownListSidebar'
import {
currentPromptAtom,
Expand All @@ -30,8 +23,11 @@ import {
} from '@/containers/Providers/Jotai'

import { compressImage, getBase64 } from '@/utils/base64'
import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'

import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'

import { loadModelErrorAtom, useActiveModel } from './useActiveModel'

import { extensionManager } from '@/extension/ExtensionManager'
Expand Down Expand Up @@ -102,39 +98,13 @@ export default function useSendChatMessage() {
return
}
updateThreadWaiting(activeThreadRef.current.id, true)
const messages: ChatCompletionMessage[] = [
activeThreadRef.current.assistants[0]?.instructions,
]
.filter((e) => e && e.trim() !== '')
.map<ChatCompletionMessage>((instructions) => {
const systemMessage: ChatCompletionMessage = {
role: ChatCompletionRole.System,
content: instructions,
}
return systemMessage
})
.concat(
currentMessages
.filter(
(e) =>
(currentMessage.role === ChatCompletionRole.User ||
e.id !== currentMessage.id) &&
e.status !== MessageStatus.Error
)
.map<ChatCompletionMessage>((msg) => ({
role: msg.role,
content: msg.content[0]?.text.value ?? '',
}))
)

const messageRequest: MessageRequest = {
id: ulid(),
type: MessageRequestType.Thread,
messages: messages,
threadId: activeThreadRef.current.id,
model:
activeThreadRef.current.assistants[0].model ?? selectedModelRef.current,
}
const requestBuilder = new MessageRequestBuilder(
MessageRequestType.Thread,
activeThreadRef.current.assistants[0].model ?? selectedModelRef.current,
activeThreadRef.current,
currentMessages
).addSystemMessage(activeThreadRef.current.assistants[0]?.instructions)

const modelId =
selectedModelRef.current?.id ??
Expand All @@ -143,7 +113,9 @@ export default function useSendChatMessage() {
if (modelRef.current?.id !== modelId) {
await startModel(modelId)
}

setIsGeneratingResponse(true)

if (currentMessage.role !== ChatCompletionRole.User) {
// Delete last response before regenerating
deleteMessage(currentMessage.id ?? '')
Expand All @@ -157,11 +129,13 @@ export default function useSendChatMessage() {
}
}
const engine = EngineManager.instance()?.get(
messageRequest.model?.engine ?? selectedModelRef.current?.engine ?? ''
requestBuilder.model?.engine ?? selectedModelRef.current?.engine ?? ''
)
engine?.inference(messageRequest)
engine?.inference(requestBuilder.build())
}

// Define interface extending Array prototype

const sendChatMessage = async (message: string) => {
if (!message || message.trim().length === 0) return

Expand All @@ -186,8 +160,6 @@ export default function useSendChatMessage() {

const fileContentType = fileUpload[0]?.type

const msgId = ulid()

const isDocumentInput = base64Blob && fileContentType === 'pdf'
const isImageInput = base64Blob && fileContentType === 'image'

Expand All @@ -196,56 +168,6 @@ export default function useSendChatMessage() {
base64Blob = await compressImage(base64Blob, 512)
}

const messages: ChatCompletionMessage[] = [
activeThreadRef.current.assistants[0]?.instructions,
]
.filter((e) => e && e.trim() !== '')
.map<ChatCompletionMessage>((instructions) => {
const systemMessage: ChatCompletionMessage = {
role: ChatCompletionRole.System,
content: instructions,
}
return systemMessage
})
.concat(
currentMessages
.filter((e) => e.status !== MessageStatus.Error)
.map<ChatCompletionMessage>((msg) => ({
role: msg.role,
content: msg.content[0]?.text.value ?? '',
}))
.concat([
{
role: ChatCompletionRole.User,
content:
selectedModelRef.current && base64Blob
? [
{
type: ChatCompletionMessageContentType.Text,
text: prompt,
},
isDocumentInput
? {
type: ChatCompletionMessageContentType.Doc,
doc_url: {
url: `threads/${activeThreadRef.current.id}/files/${msgId}.pdf`,
},
}
: null,
isImageInput
? {
type: ChatCompletionMessageContentType.Image,
image_url: {
url: base64Blob,
},
}
: null,
].filter((e) => e !== null)
: prompt,
} as ChatCompletionMessage,
])
)

let modelRequest =
selectedModelRef?.current ?? activeThreadRef.current.assistants[0].model

Expand Down Expand Up @@ -277,86 +199,48 @@ export default function useSendChatMessage() {
: {}),
}
}
const messageRequest: MessageRequest = {
id: msgId,
type: MessageRequestType.Thread,
threadId: activeThreadRef.current.id,
messages,
model: {

// Build Message Request
const requestBuilder = new MessageRequestBuilder(
MessageRequestType.Thread,
{
...modelRequest,
settings: settingParams,
parameters: runtimeParams,
},
thread: activeThreadRef.current,
}
activeThreadRef.current,
currentMessages
).addSystemMessage(activeThreadRef.current.assistants[0].instructions)

const timestamp = Date.now()
const content: any = []

if (base64Blob && fileUpload[0]?.type === 'image') {
content.push({
type: ContentType.Image,
text: {
value: prompt,
annotations: [base64Blob],
},
})
}
requestBuilder.pushMessage(prompt, base64Blob, fileUpload[0]?.type)

if (base64Blob && fileUpload[0]?.type === 'pdf') {
content.push({
type: ContentType.Pdf,
text: {
value: prompt,
annotations: [base64Blob],
name: fileUpload[0].file.name,
size: fileUpload[0].file.size,
},
})
}
// Build Thread Message to persist
const threadMessageBuilder = new ThreadMessageBuilder(
requestBuilder
).pushMessage(prompt, base64Blob, fileUpload)

if (prompt && !base64Blob) {
content.push({
type: ContentType.Text,
text: {
value: prompt,
annotations: [],
},
})
}
const newMessage = threadMessageBuilder.build()

const threadMessage: ThreadMessage = {
id: msgId,
thread_id: activeThreadRef.current.id,
role: ChatCompletionRole.User,
status: MessageStatus.Ready,
created: timestamp,
updated: timestamp,
object: 'thread.message',
content: content,
}

addNewMessage(threadMessage)
if (base64Blob) {
setFileUpload([])
}
// Push to states
addNewMessage(newMessage)

// Update thread state
const updatedThread: Thread = {
...activeThreadRef.current,
updated: timestamp,
updated: newMessage.created,
metadata: {
...(activeThreadRef.current.metadata ?? {}),
lastMessage: prompt,
},
}

// change last update thread when send message
updateThread(updatedThread)

// Add message
await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.addNewMessage(threadMessage)
?.addNewMessage(newMessage)

// Start Model if not started
const modelId =
selectedModelRef.current?.id ??
activeThreadRef.current.assistants[0].model.id
Expand All @@ -369,12 +253,17 @@ export default function useSendChatMessage() {
setIsGeneratingResponse(true)

const engine = EngineManager.instance()?.get(
messageRequest.model?.engine ?? modelRequest.engine ?? ''
requestBuilder.model?.engine ?? modelRequest.engine ?? ''
)
engine?.inference(messageRequest)
engine?.inference(requestBuilder.build())

// Reset states
setReloadModel(false)
setEngineParamsUpdate(false)

if (base64Blob) {
setFileUpload([])
}
}

return {
Expand Down
Loading

0 comments on commit 77cbdc2

Please sign in to comment.