Skip to content

Commit

Permalink
fix: the selected model auto revert back to previous used model with …
Browse files Browse the repository at this point in the history
…setting mismatch (janhq#1883)

* fix: the selected model auto revert back to previous used model with setting mismatch

* fix: view in finder and view file action
  • Loading branch information
louis-jan authored Feb 1, 2024
1 parent 4116aaa commit 5ddc6ea
Show file tree
Hide file tree
Showing 15 changed files with 132 additions and 262 deletions.
1 change: 0 additions & 1 deletion core/src/types/thread/threadEntity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,4 @@ export type ThreadState = {
waitingForResponse: boolean
error?: Error
lastMessage?: string
isFinishInit?: boolean
}
2 changes: 1 addition & 1 deletion extensions/inference-nitro-extension/bin/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.3.3
0.3.5
22 changes: 11 additions & 11 deletions extensions/inference-nitro-extension/src/node/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ function stopModel(): Promise<void> {
 * TODO: Should pass the absolute path of the model file instead of just the name, so we can modularize module.ts into an npm package
*/
async function runModel(
wrapper: ModelInitOptions
wrapper: ModelInitOptions,
): Promise<ModelOperationResponse | void> {
if (wrapper.model.engine !== InferenceEngine.nitro) {
// Not a nitro model
Expand All @@ -94,7 +94,7 @@ async function runModel(
const ggufBinFile = files.find(
(file) =>
file === path.basename(currentModelFile) ||
file.toLowerCase().includes(SUPPORTED_MODEL_FORMAT)
file.toLowerCase().includes(SUPPORTED_MODEL_FORMAT),
);

if (!ggufBinFile) return Promise.reject("No GGUF model file found");
Expand Down Expand Up @@ -189,10 +189,10 @@ function promptTemplateConverter(promptTemplate: string): PromptTemplate {
const system_prompt = promptTemplate.substring(0, systemIndex);
const user_prompt = promptTemplate.substring(
systemIndex + systemMarker.length,
promptIndex
promptIndex,
);
const ai_prompt = promptTemplate.substring(
promptIndex + promptMarker.length
promptIndex + promptMarker.length,
);

// Return the split parts
Expand All @@ -202,7 +202,7 @@ function promptTemplateConverter(promptTemplate: string): PromptTemplate {
const promptIndex = promptTemplate.indexOf(promptMarker);
const user_prompt = promptTemplate.substring(0, promptIndex);
const ai_prompt = promptTemplate.substring(
promptIndex + promptMarker.length
promptIndex + promptMarker.length,
);

// Return the split parts
Expand Down Expand Up @@ -231,8 +231,8 @@ function loadLLMModel(settings: any): Promise<Response> {
.then((res) => {
log(
`[NITRO]::Debug: Load model success with response ${JSON.stringify(
res
)}`
res,
)}`,
);
return Promise.resolve(res);
})
Expand Down Expand Up @@ -261,8 +261,8 @@ async function validateModelStatus(): Promise<void> {
}).then(async (res: Response) => {
log(
`[NITRO]::Debug: Validate model state success with response ${JSON.stringify(
res
)}`
res,
)}`,
);
// If the response is OK, check model_loaded status.
if (res.ok) {
Expand Down Expand Up @@ -313,7 +313,7 @@ function spawnNitroProcess(): Promise<any> {
const args: string[] = ["1", LOCAL_HOST, PORT.toString()];
// Execute the binary
log(
`[NITRO]::Debug: Spawn nitro at path: ${executableOptions.executablePath}, and args: ${args}`
`[NITRO]::Debug: Spawn nitro at path: ${executableOptions.executablePath}, and args: ${args}`,
);
subprocess = spawn(
executableOptions.executablePath,
Expand All @@ -324,7 +324,7 @@ function spawnNitroProcess(): Promise<any> {
...process.env,
CUDA_VISIBLE_DEVICES: executableOptions.cudaVisibleDevices,
},
}
},
);

// Handle subprocess output
Expand Down
51 changes: 14 additions & 37 deletions web/containers/DropdownListSidebar/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import { useMainViewState } from '@/hooks/useMainViewState'

import useRecommendedModel from '@/hooks/useRecommendedModel'

import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'

import { toGibibytes } from '@/utils/converter'

import ModelLabel from '../ModelLabel'
Expand All @@ -34,10 +36,8 @@ import OpenAiKeyInput from '../OpenAiKeyInput'
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'

import {
ModelParams,
activeThreadAtom,
setThreadModelParamsAtom,
threadStatesAtom,
} from '@/helpers/atoms/Thread.atom'

export const selectedModelAtom = atom<Model | undefined>(undefined)
Expand All @@ -49,7 +49,6 @@ const DropdownListSidebar = ({
strictedThread?: boolean
}) => {
const activeThread = useAtomValue(activeThreadAtom)
const threadStates = useAtomValue(threadStatesAtom)
const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)

Expand All @@ -58,15 +57,7 @@ const DropdownListSidebar = ({
const { setMainViewState } = useMainViewState()
const [loader, setLoader] = useState(0)
const { recommendedModel, downloadedModels } = useRecommendedModel()

/**
* Default value for max_tokens and ctx_len
* Its to avoid OOM issue since a model can set a big number for these settings
*/
const defaultValue = (value?: number) => {
if (value && value < 4096) return value
return 4096
}
const { updateModelParameter } = useUpdateModelParameters()

useEffect(() => {
if (!activeThread) return
Expand All @@ -78,31 +69,7 @@ const DropdownListSidebar = ({
model = recommendedModel
}
setSelectedModel(model)
const finishInit = threadStates[activeThread.id].isFinishInit ?? true
if (finishInit) return
const modelParams: ModelParams = {
...model?.parameters,
...model?.settings,
/**
* This is to set default value for these settings instead of maximum value
* Should only apply when model.json has these settings
*/
...(model?.parameters.max_tokens && {
max_tokens: defaultValue(model?.parameters.max_tokens),
}),
...(model?.settings.ctx_len && {
ctx_len: defaultValue(model?.settings.ctx_len),
}),
}
setThreadModelParams(activeThread.id, modelParams)
}, [
recommendedModel,
activeThread,
threadStates,
downloadedModels,
setThreadModelParams,
setSelectedModel,
])
}, [recommendedModel, activeThread, downloadedModels, setSelectedModel])

// This is fake loader please fix this when we have realtime percentage when load model
useEffect(() => {
Expand Down Expand Up @@ -144,7 +111,16 @@ const DropdownListSidebar = ({
...model?.parameters,
...model?.settings,
}
      // Update model parameter to the thread state
setThreadModelParams(activeThread.id, modelParams)

// Update model parameter to the thread file
if (model)
updateModelParameter(activeThread.id, {
params: modelParams,
modelId: model.id,
engine: model.engine,
})
}
},
[
Expand All @@ -154,6 +130,7 @@ const DropdownListSidebar = ({
setSelectedModel,
setServerEnabled,
setThreadModelParams,
updateModelParameter,
]
)

Expand Down
12 changes: 0 additions & 12 deletions web/helpers/atoms/Thread.atom.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,6 @@ export const deleteThreadStateAtom = atom(
}
)

export const updateThreadInitSuccessAtom = atom(
null,
(get, set, threadId: string) => {
const currentState = { ...get(threadStatesAtom) }
currentState[threadId] = {
...currentState[threadId],
isFinishInit: true,
}
set(threadStatesAtom, currentState)
}
)

export const updateThreadWaitingForResponseAtom = atom(
null,
(get, set, threadId: string, waitingForResponse: boolean) => {
Expand Down
64 changes: 29 additions & 35 deletions web/hooks/useCreateNewThread.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,21 @@ import {
} from '@janhq/core'
import { atom, useAtomValue, useSetAtom } from 'jotai'

import { selectedModelAtom } from '@/containers/DropdownListSidebar'
import { fileUploadAtom } from '@/containers/Providers/Jotai'

import { generateThreadId } from '@/utils/thread'

import useDeleteThread from './useDeleteThread'
import useRecommendedModel from './useRecommendedModel'

import useSetActiveThread from './useSetActiveThread'

import { extensionManager } from '@/extension'
import {
threadsAtom,
setActiveThreadIdAtom,
threadStatesAtom,
updateThreadAtom,
updateThreadInitSuccessAtom,
setThreadModelParamsAtom,
} from '@/helpers/atoms/Thread.atom'

const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
Expand All @@ -32,7 +34,6 @@ const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
hasMore: false,
waitingForResponse: false,
lastMessage: undefined,
isFinishInit: false,
}
currentState[newThread.id] = threadState
set(threadStatesAtom, currentState)
Expand All @@ -43,47 +44,35 @@ const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
})

export const useCreateNewThread = () => {
const threadStates = useAtomValue(threadStatesAtom)
const updateThreadFinishInit = useSetAtom(updateThreadInitSuccessAtom)
const createNewThread = useSetAtom(createNewThreadAtom)
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
const { setActiveThread } = useSetActiveThread()
const updateThread = useSetAtom(updateThreadAtom)

const setFileUpload = useSetAtom(fileUploadAtom)
const { deleteThread } = useDeleteThread()
const setSelectedModel = useSetAtom(selectedModelAtom)
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)

const { recommendedModel, downloadedModels } = useRecommendedModel()

const requestCreateNewThread = async (
assistant: Assistant,
model?: Model | undefined
) => {
// loop through threads state and filter if there's any thread that is not finish init
let unfinishedInitThreadId: string | undefined = undefined
for (const key in threadStates) {
const isFinishInit = threadStates[key].isFinishInit ?? true
if (!isFinishInit) {
unfinishedInitThreadId = key
break
}
}

if (unfinishedInitThreadId) {
await deleteThread(unfinishedInitThreadId)
}
const defaultModel = model ?? recommendedModel ?? downloadedModels[0]

const modelId = model ? model.id : '*'
const createdAt = Date.now()
const assistantInfo: ThreadAssistantInfo = {
assistant_id: assistant.id,
assistant_name: assistant.name,
tools: assistant.tools,
model: {
id: modelId,
settings: {},
parameters: {},
engine: undefined,
id: defaultModel?.id ?? '*',
settings: defaultModel?.settings ?? {},
parameters: defaultModel?.parameters ?? {},
engine: defaultModel?.engine,
},
instructions: assistant.instructions,
}

const threadId = generateThreadId(assistant.id)
const thread: Thread = {
id: threadId,
Expand All @@ -95,22 +84,27 @@ export const useCreateNewThread = () => {
}

// add the new thread on top of the thread list to the state
//TODO: Why do we have thread list then thread states? Should combine them
createNewThread(thread)
setActiveThreadId(thread.id)

setSelectedModel(defaultModel)
setThreadModelParams(thread.id, {
...defaultModel?.settings,
...defaultModel?.parameters,
})

// Delete the file upload state
setFileUpload([])
// Update thread metadata
await updateThreadMetadata(thread)

setActiveThread(thread)
}

function updateThreadMetadata(thread: Thread) {
async function updateThreadMetadata(thread: Thread) {
updateThread(thread)
const threadState = threadStates[thread.id]
const isFinishInit = threadState?.isFinishInit ?? true
if (!isFinishInit) {
updateThreadFinishInit(thread.id)
}

extensionManager
await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.saveThread(thread)
}
Expand Down
22 changes: 7 additions & 15 deletions web/hooks/useDeleteThread.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import {
threadsAtom,
setActiveThreadIdAtom,
deleteThreadStateAtom,
threadStatesAtom,
updateThreadStateLastMessageAtom,
} from '@/helpers/atoms/Thread.atom'

Expand All @@ -34,7 +33,6 @@ export default function useDeleteThread() {
const deleteMessages = useSetAtom(deleteChatMessagesAtom)
const cleanMessages = useSetAtom(cleanChatMessagesAtom)
const deleteThreadState = useSetAtom(deleteThreadStateAtom)
const threadStates = useAtomValue(threadStatesAtom)
const updateThreadLastMessage = useSetAtom(updateThreadStateLastMessageAtom)

const cleanThread = async (threadId: string) => {
Expand Down Expand Up @@ -74,22 +72,16 @@ export default function useDeleteThread() {
const availableThreads = threads.filter((c) => c.id !== threadId)
setThreads(availableThreads)

const deletingThreadState = threadStates[threadId]
const isFinishInit = deletingThreadState?.isFinishInit ?? true

// delete the thread state
deleteThreadState(threadId)

if (isFinishInit) {
deleteMessages(threadId)
setCurrentPrompt('')
toaster({
title: 'Thread successfully deleted.',
description: `Thread ${threadId} has been successfully deleted.`,
type: 'success',
})
}

deleteMessages(threadId)
setCurrentPrompt('')
toaster({
title: 'Thread successfully deleted.',
description: `Thread ${threadId} has been successfully deleted.`,
type: 'success',
})
if (availableThreads.length > 0) {
setActiveThreadId(availableThreads[0].id)
} else {
Expand Down
Loading

0 comments on commit 5ddc6ea

Please sign in to comment.