Skip to content

Commit

Permalink
fix: cancel loading model with stop action (janhq#2607)
Browse files Browse the repository at this point in the history
  • Loading branch information
louis-jan authored Apr 4, 2024
1 parent 7f92a5a commit 1eaf13b
Show file tree
Hide file tree
Showing 11 changed files with 96 additions and 45 deletions.
1 change: 1 addition & 0 deletions extensions/inference-nitro-extension/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"path-browserify": "^1.0.1",
"rxjs": "^7.8.1",
"tcp-port-used": "^1.0.2",
"terminate": "^2.6.1",
"ulidx": "^2.3.0"
},
"engines": {
Expand Down
51 changes: 36 additions & 15 deletions extensions/inference-nitro-extension/src/node/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
SystemInformation,
} from '@janhq/core/node'
import { executableNitroFile } from './execute'
import terminate from 'terminate'

// Polyfill fetch with retry
const fetchRetry = fetchRT(fetch)
Expand Down Expand Up @@ -304,23 +305,43 @@ async function killSubprocess(): Promise<void> {
setTimeout(() => controller.abort(), 5000)
log(`[NITRO]::Debug: Request to kill Nitro`)

return fetch(NITRO_HTTP_KILL_URL, {
method: 'DELETE',
signal: controller.signal,
})
.then(() => {
subprocess?.kill()
subprocess = undefined
const killRequest = () => {
return fetch(NITRO_HTTP_KILL_URL, {
method: 'DELETE',
signal: controller.signal,
})
.catch(() => {}) // Do nothing with this attempt
.then(() => tcpPortUsed.waitUntilFree(PORT, 300, 5000))
.then(() => log(`[NITRO]::Debug: Nitro process is terminated`))
.catch((err) => {
log(
`[NITRO]::Debug: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}`
)
throw 'PORT_NOT_AVAILABLE'
.catch(() => {}) // Do nothing with this attempt
.then(() => tcpPortUsed.waitUntilFree(PORT, 300, 5000))
.then(() => log(`[NITRO]::Debug: Nitro process is terminated`))
.catch((err) => {
log(
`[NITRO]::Debug: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}`
)
throw 'PORT_NOT_AVAILABLE'
})
}

if (subprocess?.pid) {
log(`[NITRO]::Debug: Killing PID ${subprocess.pid}`)
const pid = subprocess.pid
return new Promise((resolve, reject) => {
terminate(pid, function (err) {
if (err) {
return killRequest()
} else {
return tcpPortUsed
.waitUntilFree(PORT, 300, 5000)
.then(() => resolve())
.then(() => log(`[NITRO]::Debug: Nitro process is terminated`))
.catch(() => {
killRequest()
})
}
})
})
} else {
return killRequest()
}
}

/**
Expand Down
2 changes: 1 addition & 1 deletion web/containers/Loader/ModelReload.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ export default function ModelReload() {
style={{ width: `${loader}%` }}
/>
<span className="relative z-10">
Reloading model {stateModel.model}
Reloading model {stateModel.model?.id}
</span>
</div>
</div>
Expand Down
2 changes: 1 addition & 1 deletion web/containers/Loader/ModelStart.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ export default function ModelStart() {
<span className="relative z-10">
{stateModel.state === 'start' ? 'Starting' : 'Stopping'}
&nbsp;model&nbsp;
{stateModel.model}
{stateModel.model?.id}
</span>
</div>
</div>
Expand Down
2 changes: 1 addition & 1 deletion web/containers/Providers/EventHandler.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ export default function EventHandler({ children }: { children: ReactNode }) {

const onModelStopped = useCallback(() => {
setActiveModel(undefined)
setStateModel({ state: 'start', loading: false, model: '' })
setStateModel({ state: 'start', loading: false, model: undefined })
}, [setActiveModel, setStateModel])

const updateThreadTitle = useCallback(
Expand Down
63 changes: 44 additions & 19 deletions web/hooks/useActiveModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,16 @@ import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
export const activeModelAtom = atom<Model | undefined>(undefined)
export const loadModelErrorAtom = atom<string | undefined>(undefined)

export const stateModelAtom = atom({
/**
 * Shape of the value stored in `stateModelAtom`: tracks an in-flight model
 * start/stop transition so the UI can render progress for it.
 */
type ModelState = {
// Which transition is being tracked. The code in this file only ever assigns
// 'start' or 'stop'; the type is left as a plain string.
state: string
// True while the transition named by `state` is still in progress.
loading: boolean
// The model the transition applies to; undefined when no transition is tracked.
model?: Model
}

export const stateModelAtom = atom<ModelState>({
state: 'start',
loading: false,
model: '',
model: undefined,
})

export function useActiveModel() {
Expand All @@ -35,7 +41,7 @@ export function useActiveModel() {
const startModel = async (modelId: string) => {
if (
(activeModel && activeModel.id === modelId) ||
(stateModel.model === modelId && stateModel.loading)
(stateModel.model?.id === modelId && stateModel.loading)
) {
console.debug(`Model ${modelId} is already initialized. Ignore..`)
return Promise.resolve()
Expand All @@ -52,7 +58,7 @@ export function useActiveModel() {

setActiveModel(undefined)

setStateModel({ state: 'start', loading: true, model: modelId })
setStateModel({ state: 'start', loading: true, model })

if (!model) {
toaster({
Expand All @@ -63,7 +69,7 @@ export function useActiveModel() {
setStateModel(() => ({
state: 'start',
loading: false,
model: '',
model: undefined,
}))

return Promise.reject(`Model ${modelId} not found!`)
Expand All @@ -89,7 +95,7 @@ export function useActiveModel() {
setStateModel(() => ({
state: 'stop',
loading: false,
model: model.id,
model,
}))
toaster({
title: 'Success!',
Expand All @@ -101,7 +107,7 @@ export function useActiveModel() {
setStateModel(() => ({
state: 'start',
loading: false,
model: model.id,
model,
}))

toaster({
Expand All @@ -114,20 +120,39 @@ export function useActiveModel() {
})
}

const stopModel = useCallback(async () => {
if (!activeModel || (stateModel.state === 'stop' && stateModel.loading))
/**
 * Unloads a model and resets the model state atoms.
 *
 * The currently active model takes precedence; the optional `model` argument
 * is used as a fallback (e.g. to cancel a model that is still loading and has
 * therefore not become the active model yet).
 *
 * @param model - Model to stop when there is no active model.
 * @returns A promise that resolves once the unload attempt has finished and
 *   the state has been reset. Unload failures are logged, not rethrown.
 */
const stopModel = useCallback(
  async (model?: Model) => {
    const stoppingModel = activeModel || model
    // Nothing to stop, or a stop is already in flight and the caller did not
    // explicitly name a model to cancel.
    if (
      !stoppingModel ||
      (!model && stateModel.state === 'stop' && stateModel.loading)
    )
      return

    setStateModel({ state: 'stop', loading: true, model: stoppingModel })
    const engine = EngineManager.instance().get(stoppingModel.engine)
    try {
      // `engine` may be undefined when no engine is registered for this model;
      // the optional call then resolves immediately.
      await engine?.unloadModel(stoppingModel)
    } catch (err) {
      // Best-effort unload: a failure must not leave the UI stuck in the
      // "stopping" state, so log it and fall through to the reset below.
      console.error(`Failed to unload model ${stoppingModel.id}:`, err)
    } finally {
      // Always clear the transition, whether the unload succeeded, failed, or
      // there was no engine to call. The previous `.catch().then(...)` chain
      // skipped this reset on rejection because `.catch()` without a handler
      // does not handle anything.
      setActiveModel(undefined)
      setStateModel({ state: 'start', loading: false, model: undefined })
    }
  },
  [activeModel, setActiveModel, setStateModel, stateModel]
)

const stopInference = useCallback(async () => {
// Loading model
if (stateModel.loading) {
stopModel(stateModel.model)
return
}
if (!activeModel) return

setStateModel({ state: 'stop', loading: true, model: activeModel.id })
const engine = EngineManager.instance().get(activeModel.engine)
await engine
?.unloadModel(activeModel)
.catch()
.then(() => {
setActiveModel(undefined)
setStateModel({ state: 'start', loading: false, model: '' })
})
}, [activeModel, stateModel, setActiveModel, setStateModel])
engine?.stopInference()
}, [activeModel, stateModel, stopModel])

return { activeModel, startModel, stopModel, stateModel }
return { activeModel, startModel, stopModel, stopInference, stateModel }
}
4 changes: 3 additions & 1 deletion web/hooks/useCreateNewThread.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import { fileUploadAtom } from '@/containers/Providers/Jotai'

import { generateThreadId } from '@/utils/thread'

import { useActiveModel } from './useActiveModel'
import useRecommendedModel from './useRecommendedModel'

import useSetActiveThread from './useSetActiveThread'
Expand Down Expand Up @@ -65,14 +66,15 @@ export const useCreateNewThread = () => {
const { recommendedModel, downloadedModels } = useRecommendedModel()

const threads = useAtomValue(threadsAtom)
const { stopInference } = useActiveModel()

const requestCreateNewThread = async (
assistant: Assistant,
model?: Model | undefined
) => {
// Stop generating if any
setIsGeneratingResponse(false)
events.emit(InferenceEvent.OnInferenceStopped, {})
stopInference()

const defaultModel = model ?? recommendedModel ?? downloadedModels[0]

Expand Down
5 changes: 3 additions & 2 deletions web/screens/Chat/ChatInput/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ import {

const ChatInput: React.FC = () => {
const activeThread = useAtomValue(activeThreadAtom)
const { stateModel } = useActiveModel()
const { stateModel, activeModel } = useActiveModel()
const messages = useAtomValue(getCurrentChatMessagesAtom)

const [currentPrompt, setCurrentPrompt] = useAtom(currentPromptAtom)
Expand All @@ -60,6 +60,7 @@ const ChatInput: React.FC = () => {
const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom)
const isGeneratingResponse = useAtomValue(isGeneratingResponseAtom)
const threadStates = useAtomValue(threadStatesAtom)
const { stopInference } = useActiveModel()

const isStreamingResponse = Object.values(threadStates).some(
(threadState) => threadState.waitingForResponse
Expand Down Expand Up @@ -107,7 +108,7 @@ const ChatInput: React.FC = () => {
}

const onStopInferenceClick = async () => {
events.emit(InferenceEvent.OnInferenceStopped, {})
stopInference()
}

/**
Expand Down
4 changes: 2 additions & 2 deletions web/screens/Chat/EditChatInput/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ type Props = {

const EditChatInput: React.FC<Props> = ({ message }) => {
const activeThread = useAtomValue(activeThreadAtom)
const { stateModel } = useActiveModel()
const { stateModel, stopInference } = useActiveModel()
const messages = useAtomValue(getCurrentChatMessagesAtom)

const [editPrompt, setEditPrompt] = useAtom(editPromptAtom)
Expand Down Expand Up @@ -127,7 +127,7 @@ const EditChatInput: React.FC<Props> = ({ message }) => {
}

const onStopInferenceClick = async () => {
events.emit(InferenceEvent.OnInferenceStopped, {})
stopInference()
}

return (
Expand Down
3 changes: 2 additions & 1 deletion web/screens/Chat/LoadModelError/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ const LoadModelError = () => {
<ModalTroubleShooting />
</div>
) : loadModelError &&
loadModelError?.includes('EXTENSION_IS_NOT_INSTALLED') ? (
typeof loadModelError.includes === 'function' &&
loadModelError.includes('EXTENSION_IS_NOT_INSTALLED') ? (
<div className="flex w-full flex-col items-center text-center text-sm font-medium text-gray-500">
<p className="w-[90%]">
Model is currently unavailable. Please switch to a different model
Expand Down
4 changes: 2 additions & 2 deletions web/screens/Settings/Models/Row.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ export default function RowModel(props: RowModelProps) {
const { activeModel, startModel, stopModel, stateModel } = useActiveModel()
const { deleteModel } = useDeleteModel()

const isActiveModel = stateModel.model === props.data.id
const isActiveModel = stateModel.model?.id === props.data.id

const [serverEnabled, setServerEnabled] = useAtom(serverEnabledAtom)

Expand Down Expand Up @@ -84,7 +84,7 @@ export default function RowModel(props: RowModelProps) {
<span className="h-2 w-2 rounded-full bg-green-500" />
<span>Active</span>
</Badge>
) : stateModel.loading && stateModel.model === props.data.id ? (
) : stateModel.loading && stateModel.model?.id === props.data.id ? (
<Badge
className="inline-flex items-center space-x-2"
themes="secondary"
Expand Down

0 comments on commit 1eaf13b

Please sign in to comment.