Skip to content

Commit

Permalink
fix: Use Events for init, load, stop models
Browse files Browse the repository at this point in the history
  • Loading branch information
hiro-v committed Dec 8, 2023
1 parent 9aca37a commit 1bc5fe6
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 55 deletions.
8 changes: 8 additions & 0 deletions core/src/events.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ export enum EventName {
OnMessageResponse = "OnMessageResponse",
/** The `OnMessageUpdate` event is emitted when a message is updated. */
OnMessageUpdate = "OnMessageUpdate",
/** The `OnModelInit` event is emitted when a model starts initializing. */
OnModelInit = "OnModelInit",
/** The `OnModelReady` event is emitted when a model is ready. */
OnModelReady = "OnModelReady",
/** The `OnModelFail` event is emitted when a model fails loading. */
OnModelFail = "OnModelFail",
/** The `OnModelStop` event is emitted when a model stops. */
OnModelStop = "OnModelStop",
}

/**
Expand Down
15 changes: 15 additions & 0 deletions core/src/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,17 @@ export type ThreadState = {
error?: Error;
lastMessage?: string;
};
/**
* Represents the inference engine.
* @stored
*/

enum InferenceEngine {
  // Local inference backend (llama.cpp based).
  llama_cpp = "llama_cpp",
  // Hosted OpenAI API backend.
  openai = "openai",
  // NVIDIA Triton inference server backend.
  nvidia_triton = "nvidia_triton",
  // Hugging Face inference endpoint backend.
  hf_endpoint = "hf_endpoint",
}

/**
* Model type defines the shape of a model object.
Expand Down Expand Up @@ -234,6 +245,10 @@ export interface Model {
* Metadata of the model.
*/
metadata: ModelMetadata;
/**
* The inference engine used by the model, e.g. "llama_cpp" or "openai".
*/
engine: InferenceEngine;
}

export type ModelMetadata = {
Expand Down
49 changes: 48 additions & 1 deletion web/containers/Providers/EventHandler.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ import {
ThreadMessage,
ExtensionType,
MessageStatus,
Model
} from '@janhq/core'
import { ConversationalExtension } from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'

import { extensionManager } from '@/extension'
import {
Expand All @@ -21,9 +22,16 @@ import {
threadsAtom,
} from '@/helpers/atoms/Conversation.atom'

import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
import { useGetDownloadedModels } from '@/hooks/useGetDownloadedModels'
import { toaster } from '../Toast'

export default function EventHandler({ children }: { children: ReactNode }) {
const addNewMessage = useSetAtom(addNewMessageAtom)
const updateMessage = useSetAtom(updateMessageAtom)
const { downloadedModels } = useGetDownloadedModels()
const [activeModel, setActiveModel] = useAtom(activeModelAtom)
const [stateModel, setStateModel] = useAtom(stateModelAtom)

const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
const threads = useAtomValue(threadsAtom)
Expand All @@ -37,6 +45,42 @@ export default function EventHandler({ children }: { children: ReactNode }) {
addNewMessage(message)
}

/**
 * Handles the `OnModelReady` event: marks the matching downloaded model as
 * active, notifies the user, and flips the model state to the running shape.
 */
async function handleModelReady(res: any) {
  // Resolve the event's model id against the locally downloaded models;
  // may be undefined if the id is unknown, which is allowed by activeModelAtom.
  const readyModel = downloadedModels.find((dm) => dm.id === res.modelId)
  setActiveModel(readyModel)
  toaster({
    title: 'Success!',
    description: `Model ${res.modelId} has been started.`,
  })
  // 'stop' here is the action now available to the user (the model is running).
  setStateModel({
    state: 'stop',
    loading: false,
    model: res.modelId,
  })
}

/**
 * Handles the `OnModelStop` event: clears the active model, resets the model
 * state to its idle shape, and notifies the user.
 *
 * Fix: removed an unused local (`downloadedModels.find(...)` whose result was
 * never read) and the unnecessary `async` on the setTimeout callback.
 */
async function handleModelStop(res: any) {
  // Small delay before tearing down UI state — presumably to let any
  // in-flight updates settle first; TODO confirm the 500ms is intentional.
  setTimeout(() => {
    setActiveModel(undefined)
    setStateModel({ state: 'start', loading: false, model: '' })
    toaster({
      title: 'Success!',
      description: `Model ${res.modelId} has been stopped.`,
    })
  }, 500)
}

/**
 * Handles the `OnModelFail` event: surfaces the error to the user and resets
 * the model state back to 'start' so the model can be retried.
 */
async function handleModelFail(res: any) {
  // NOTE(review): a blocking alert is used here rather than the toaster —
  // confirm this is deliberate.
  alert(`${res.error}`)
  setStateModel({
    state: 'start',
    loading: false,
    model: res.modelId,
  })
}

async function handleMessageResponseUpdate(message: ThreadMessage) {
updateMessage(
message.id,
Expand Down Expand Up @@ -73,6 +117,9 @@ export default function EventHandler({ children }: { children: ReactNode }) {
if (window.core.events) {
events.on(EventName.OnMessageResponse, handleNewMessageResponse)
events.on(EventName.OnMessageUpdate, handleMessageResponseUpdate)
events.on(EventName.OnModelReady, handleModelReady)
events.on(EventName.OnModelFail, handleModelFail)
events.on(EventName.OnModelStop, handleModelStop)
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])
Expand Down
56 changes: 5 additions & 51 deletions web/hooks/useActiveModel.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { ExtensionType, InferenceExtension } from '@janhq/core'
import { EventName, ExtensionType, InferenceExtension, events } from '@janhq/core'
import { Model, ModelSettingParams } from '@janhq/core'
import { atom, useAtom } from 'jotai'

Expand All @@ -9,9 +9,9 @@ import { useGetDownloadedModels } from './useGetDownloadedModels'

import { extensionManager } from '@/extension'

const activeModelAtom = atom<Model | undefined>(undefined)
export const activeModelAtom = atom<Model | undefined>(undefined)

const stateModelAtom = atom({ state: 'start', loading: false, model: '' })
export const stateModelAtom = atom({ state: 'start', loading: false, model: '' })

export function useActiveModel() {
const [activeModel, setActiveModel] = useAtom(activeModelAtom)
Expand Down Expand Up @@ -47,59 +47,13 @@ export function useActiveModel() {
return
}

const currentTime = Date.now()
const res = await initModel(modelId, model?.settings)
if (res && res.error) {
const errorMessage = `${res.error}`
alert(errorMessage)
setStateModel(() => ({
state: 'start',
loading: false,
model: modelId,
}))
} else {
console.debug(
`Model ${modelId} successfully initialized! Took ${
Date.now() - currentTime
}ms`
)
setActiveModel(model)
toaster({
title: 'Success!',
description: `Model ${modelId} has been started.`,
})
setStateModel(() => ({
state: 'stop',
loading: false,
model: modelId,
}))
}
events.emit(EventName.OnModelInit, model)
}

const stopModel = async (modelId: string) => {
setStateModel({ state: 'stop', loading: true, model: modelId })
setTimeout(async () => {
extensionManager
.get<InferenceExtension>(ExtensionType.Inference)
?.stopModel()

setActiveModel(undefined)
setStateModel({ state: 'start', loading: false, model: '' })
toaster({
title: 'Success!',
description: `Model ${modelId} has been stopped.`,
})
}, 500)
events.emit(EventName.OnModelStop, modelId)
}

return { activeModel, startModel, stopModel, stateModel }
}

const initModel = async (
modelId: string,
settings?: ModelSettingParams
): Promise<any> => {
return extensionManager
.get<InferenceExtension>(ExtensionType.Inference)
?.initModel(modelId, settings)
}
20 changes: 17 additions & 3 deletions web/screens/ExploreModels/ExploreModelItemHeader/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,23 @@ const ExploreModelItemHeader: React.FC<Props> = ({ model, onClick, open }) => {

const isDownloaded = downloadedModels.find((md) => md.id === model.id) != null

let downloadButton = (
<Button onClick={() => onDownloadClick()}>Download</Button>
)
// Fix: initialize `downloadButton` up front (the original `let` could be
// left unassigned as far as the checker knows) and drop the redundant
// `else if (model.engine === 'nitro')` — the two conditions are exhaustive
// complements, so a plain `if` suffices.
//
// Non-nitro engines get a "Use" button; nitro models get a Download button
// that includes the file size when metadata provides one.
let downloadButton = (
  <Button onClick={() => onDownloadClick()}>Use</Button>
)

if (model.engine === 'nitro') {
  downloadButton = (
    <Button onClick={() => onDownloadClick()}>
      {model.metadata.size
        ? `Download (${toGigabytes(model.metadata.size)})`
        : 'Download'}
    </Button>
  )
}

const onUseModelClick = () => {
startModel(model.id)
Expand Down

0 comments on commit 1bc5fe6

Please sign in to comment.