Skip to content

Commit

Permalink
feat: Thread titles should auto-summarize Topic (janhq#1976)
Browse files Browse the repository at this point in the history
  • Loading branch information
0xgokuz authored Feb 10, 2024
1 parent 5864f49 commit 875c2bc
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 7 deletions.
1 change: 1 addition & 0 deletions core/src/types/message/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
export * from './messageEntity'
export * from './messageInterface'
export * from './messageEvent'
export * from './messageRequestType'
4 changes: 4 additions & 0 deletions core/src/types/message/messageEntity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ export type ThreadMessage = {
updated: number
/** The additional metadata of this message. **/
metadata?: Record<string, unknown>

type?: string
}

/**
Expand Down Expand Up @@ -56,6 +58,8 @@ export type MessageRequest = {
/** The thread of this message is belong to. **/
// TODO: deprecate threadId field
thread?: Thread

type?: string
}

/**
Expand Down
5 changes: 5 additions & 0 deletions core/src/types/message/messageRequestType.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
export enum MessageRequestType {
Thread = 'Thread',
Assistant = 'Assistant',
Summary = 'Summary',
}
7 changes: 6 additions & 1 deletion extensions/inference-nitro-extension/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
ChatCompletionRole,
ContentType,
MessageRequest,
MessageRequestType,
MessageStatus,
ThreadContent,
ThreadMessage,
Expand Down Expand Up @@ -250,6 +251,7 @@ export default class JanInferenceNitroExtension extends InferenceExtension {
const message: ThreadMessage = {
id: ulid(),
thread_id: data.threadId,
type: data.type,
assistant_id: data.assistantId,
role: ChatCompletionRole.Assistant,
content: [],
Expand All @@ -258,7 +260,10 @@ export default class JanInferenceNitroExtension extends InferenceExtension {
updated: timestamp,
object: "thread.message",
};
events.emit(MessageEvent.OnMessageResponse, message);

if (data.type !== MessageRequestType.Summary) {
events.emit(MessageEvent.OnMessageResponse, message);
}

this.isCancelled = false;
this.controller = new AbortController();
Expand Down
7 changes: 6 additions & 1 deletion extensions/inference-openai-extension/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import {
InferenceEngine,
BaseExtension,
MessageEvent,
MessageRequestType,
ModelEvent,
InferenceEvent,
AppConfigurationEventName,
Expand Down Expand Up @@ -157,6 +158,7 @@ export default class JanInferenceOpenAIExtension extends BaseExtension {
const message: ThreadMessage = {
id: ulid(),
thread_id: data.threadId,
type: data.type,
assistant_id: data.assistantId,
role: ChatCompletionRole.Assistant,
content: [],
Expand All @@ -165,7 +167,10 @@ export default class JanInferenceOpenAIExtension extends BaseExtension {
updated: timestamp,
object: "thread.message",
};
events.emit(MessageEvent.OnMessageResponse, message);

if (data.type !== MessageRequestType.Summary) {
events.emit(MessageEvent.OnMessageResponse, message);
}

instance.isCancelled = false;
instance.controller = new AbortController();
Expand Down
68 changes: 64 additions & 4 deletions web/containers/Providers/EventHandler.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@
import { ReactNode, useCallback, useEffect, useRef } from 'react'

import {
ChatCompletionMessage,
ChatCompletionRole,
events,
ThreadMessage,
ExtensionTypeEnum,
MessageStatus,
MessageRequest,
Model,
ConversationalExtension,
MessageEvent,
MessageRequestType,
ModelEvent,
} from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai'
import { ulid } from 'ulid'

import {
activeModelAtom,
Expand All @@ -25,6 +30,7 @@ import { toaster } from '../Toast'

import { extensionManager } from '@/extension'
import {
getCurrentChatMessagesAtom,
addNewMessageAtom,
updateMessageAtom,
} from '@/helpers/atoms/ChatMessage.atom'
Expand All @@ -37,9 +43,11 @@ import {
} from '@/helpers/atoms/Thread.atom'

export default function EventHandler({ children }: { children: ReactNode }) {
const messages = useAtomValue(getCurrentChatMessagesAtom)
const addNewMessage = useSetAtom(addNewMessageAtom)
const updateMessage = useSetAtom(updateMessageAtom)
const downloadedModels = useAtomValue(downloadedModelsAtom)
const activeModel = useAtomValue(activeModelAtom)
const setActiveModel = useSetAtom(activeModelAtom)
const setStateModel = useSetAtom(stateModelAtom)
const setQueuedMessage = useSetAtom(queuedMessageAtom)
Expand All @@ -51,6 +59,8 @@ export default function EventHandler({ children }: { children: ReactNode }) {
const threadsRef = useRef(threads)
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
const updateThread = useSetAtom(updateThreadAtom)
const messagesRef = useRef(messages)
const activeModelRef = useRef(activeModel)

useEffect(() => {
threadsRef.current = threads
Expand All @@ -60,9 +70,51 @@ export default function EventHandler({ children }: { children: ReactNode }) {
modelsRef.current = downloadedModels
}, [downloadedModels])

useEffect(() => {
messagesRef.current = messages
}, [messages])

useEffect(() => {
activeModelRef.current = activeModel
}, [activeModel])

const onNewMessageResponse = useCallback(
(message: ThreadMessage) => {
addNewMessage(message)
const thread = threadsRef.current?.find((e) => e.id == message.thread_id)
// If this is the first ever prompt in the thread
if (thread && thread.title.trim() == 'New Thread') {
// This is the first time message comes in on a new thread
// Summarize the first message, and make that the title of the Thread
// 1. Get the summary of the first prompt using whatever engine user is currently using
const firstPrompt = messagesRef?.current[0].content[0].text.value.trim()
const summarizeFirstPrompt =
'Summarize "' + firstPrompt + '" in 5 words as a title'

// Prompt: Given this query from user {query}, return to me the summary in 5 words as the title
const msgId = ulid()
const messages: ChatCompletionMessage[] = [
{
role: ChatCompletionRole.User,
content: summarizeFirstPrompt,
} as ChatCompletionMessage,
]

const firstPromptRequest: MessageRequest = {
id: msgId,
threadId: message.thread_id,
type: MessageRequestType.Summary,
messages,
model: activeModelRef?.current,
}

// 2. Update the title with the result of the inference
// the title will be updated as part of the `EventName.OnFirstPromptUpdate`
events.emit(MessageEvent.OnMessageSent, firstPromptRequest)
}

if (message.type !== MessageRequestType.Summary) {
addNewMessage(message)
}
},
[addNewMessage]
)
Expand Down Expand Up @@ -134,6 +186,11 @@ export default function EventHandler({ children }: { children: ReactNode }) {
...(messageContent && { lastMessage: messageContent }),
}

// Update the Thread title with the response of the inference on the 1st prompt
if (message.type === MessageRequestType.Summary) {
thread.title = messageContent
}

updateThread({
...thread,
metadata,
Expand All @@ -146,9 +203,12 @@ export default function EventHandler({ children }: { children: ReactNode }) {
metadata,
})

extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.addNewMessage(message)
// If this is not the summary of the Thread, don't need to add it to the Thread
if (message.type !== MessageRequestType.Summary) {
extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.addNewMessage(message)
}
}
},
[updateMessage, updateThreadWaiting, setIsGeneratingResponse, updateThread]
Expand Down
5 changes: 4 additions & 1 deletion web/hooks/useSendChatMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
ChatCompletionRole,
ContentType,
MessageRequest,
MessageRequestType,
MessageStatus,
ExtensionTypeEnum,
Thread,
Expand Down Expand Up @@ -112,6 +113,7 @@ export default function useSendChatMessage() {

const messageRequest: MessageRequest = {
id: ulid(),
type: MessageRequestType.Thread,
messages: messages,
threadId: activeThread.id,
model: activeThread.assistants[0].model ?? selectedModel,
Expand Down Expand Up @@ -209,6 +211,7 @@ export default function useSendChatMessage() {
}
const messageRequest: MessageRequest = {
id: msgId,
type: MessageRequestType.Thread,
threadId: activeThread.id,
messages,
model: {
Expand All @@ -218,8 +221,8 @@ export default function useSendChatMessage() {
},
thread: activeThread,
}
const timestamp = Date.now()

const timestamp = Date.now()
const content: any = []

if (base64Blob && fileUpload[0]?.type === 'image') {
Expand Down

0 comments on commit 875c2bc

Please sign in to comment.