Skip to content

Commit

Permalink
feat: add engine settings (janhq#1199)
Browse files Browse the repository at this point in the history
* feat: add engine settings

Signed-off-by: James <[email protected]>
---------

Signed-off-by: James <[email protected]>
Co-authored-by: Louis <[email protected]>
  • Loading branch information
namchuai and louis-jan authored Dec 28, 2023
1 parent 2df43e9 commit c580b4c
Show file tree
Hide file tree
Showing 17 changed files with 442 additions and 163 deletions.
29 changes: 7 additions & 22 deletions web/containers/Checkbox/index.tsx
Original file line number Diff line number Diff line change
@@ -1,48 +1,33 @@
import { FieldValues, UseFormRegister } from 'react-hook-form'
import React from 'react'

import { ModelRuntimeParams } from '@janhq/core'
import { Switch } from '@janhq/uikit'

import { useAtomValue } from 'jotai'

import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'

import {
getActiveThreadIdAtom,
getActiveThreadModelRuntimeParamsAtom,
} from '@/helpers/atoms/Thread.atom'
import { getActiveThreadIdAtom } from '@/helpers/atoms/Thread.atom'

type Props = {
name: string
title: string
checked: boolean
register: UseFormRegister<FieldValues>
}

const Checkbox: React.FC<Props> = ({ name, title, checked, register }) => {
const Checkbox: React.FC<Props> = ({ name, title, checked }) => {
const { updateModelParameter } = useUpdateModelParameters()
const threadId = useAtomValue(getActiveThreadIdAtom)
const activeModelParams = useAtomValue(getActiveThreadModelRuntimeParamsAtom)

const onCheckedChange = (checked: boolean) => {
if (!threadId || !activeModelParams) return
if (!threadId) return

const updatedModelParams: ModelRuntimeParams = {
...activeModelParams,
[name]: checked,
}

updateModelParameter(threadId, updatedModelParams)
updateModelParameter(threadId, name, checked)
}

return (
<div className="flex justify-between">
<label>{title}</label>
<Switch
checked={checked}
{...register(name)}
onCheckedChange={onCheckedChange}
/>
<p className="mb-2 text-sm font-semibold text-gray-600">{title}</p>
<Switch checked={checked} onCheckedChange={onCheckedChange} />
</div>
)
}
Expand Down
42 changes: 37 additions & 5 deletions web/containers/DropdownListSidebar/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,22 @@ import useRecommendedModel from '@/hooks/useRecommendedModel'

import { toGigabytes } from '@/utils/converter'

import { activeThreadAtom, threadStatesAtom } from '@/helpers/atoms/Thread.atom'
import {
activeThreadAtom,
getActiveThreadIdAtom,
setThreadModelParamsAtom,
threadStatesAtom,
} from '@/helpers/atoms/Thread.atom'

export const selectedModelAtom = atom<Model | undefined>(undefined)

export default function DropdownListSidebar() {
const setSelectedModel = useSetAtom(selectedModelAtom)
const threadStates = useAtomValue(threadStatesAtom)
const activeThreadId = useAtomValue(getActiveThreadIdAtom)
const activeThread = useAtomValue(activeThreadAtom)
const threadStates = useAtomValue(threadStatesAtom)
const setSelectedModel = useSetAtom(selectedModelAtom)
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)

const [selected, setSelected] = useState<Model | undefined>()
const { setMainViewState } = useMainViewState()
const [openAISettings, setOpenAISettings] = useState<
Expand All @@ -54,15 +62,39 @@ export default function DropdownListSidebar() {
useEffect(() => {
setSelected(recommendedModel)
setSelectedModel(recommendedModel)
}, [recommendedModel, setSelectedModel])

if (activeThread) {
const finishInit = threadStates[activeThread.id].isFinishInit ?? true
if (finishInit) return
const modelParams = {
...recommendedModel?.parameters,
...recommendedModel?.settings,
}
setThreadModelParams(activeThread.id, modelParams)
}
}, [
recommendedModel,
activeThread,
setSelectedModel,
setThreadModelParams,
threadStates,
])

const onValueSelected = useCallback(
(modelId: string) => {
const model = downloadedModels.find((m) => m.id === modelId)
setSelected(model)
setSelectedModel(model)

if (activeThreadId) {
const modelParams = {
...model?.parameters,
...model?.settings,
}
setThreadModelParams(activeThreadId, modelParams)
}
},
[downloadedModels, setSelectedModel]
[downloadedModels, activeThreadId, setSelectedModel, setThreadModelParams]
)

if (!activeThread) {
Expand Down
43 changes: 43 additions & 0 deletions web/containers/ModelConfigInput/index.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import { Textarea } from '@janhq/uikit'

import { useAtomValue } from 'jotai'

import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'

import { getActiveThreadIdAtom } from '@/helpers/atoms/Thread.atom'

type Props = {
title: string
name: string
placeholder: string
value: string
}

const ModelConfigInput: React.FC<Props> = ({
title,
name,
value,
placeholder,
}) => {
const { updateModelParameter } = useUpdateModelParameters()
const threadId = useAtomValue(getActiveThreadIdAtom)

const onValueChanged = (e: React.ChangeEvent<HTMLTextAreaElement>) => {
if (!threadId) return

updateModelParameter(threadId, name, e.target.value)
}

return (
<div className="flex flex-col">
<p className="mb-2 text-sm font-semibold text-gray-600">{title}</p>
<Textarea
placeholder={placeholder}
onChange={onValueChanged}
value={value}
/>
</div>
)
}

export default ModelConfigInput
23 changes: 4 additions & 19 deletions web/containers/Slider/index.tsx
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import { FieldValues, UseFormRegister } from 'react-hook-form'
import React from 'react'

import { ModelRuntimeParams } from '@janhq/core'
import { Slider, Input } from '@janhq/uikit'
import { useAtomValue } from 'jotai'

import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'

import {
getActiveThreadIdAtom,
getActiveThreadModelRuntimeParamsAtom,
} from '@/helpers/atoms/Thread.atom'
import { getActiveThreadIdAtom } from '@/helpers/atoms/Thread.atom'

type Props = {
name: string
Expand All @@ -18,7 +14,6 @@ type Props = {
max: number
step: number
value: number
register: UseFormRegister<FieldValues>
}

const SliderRightPanel: React.FC<Props> = ({
Expand All @@ -28,21 +23,14 @@ const SliderRightPanel: React.FC<Props> = ({
max,
step,
value,
register,
}) => {
const { updateModelParameter } = useUpdateModelParameters()
const threadId = useAtomValue(getActiveThreadIdAtom)
const activeModelParams = useAtomValue(getActiveThreadModelRuntimeParamsAtom)

const onValueChanged = (e: number[]) => {
if (!threadId || !activeModelParams) return
if (!threadId) return

const updatedModelParams: ModelRuntimeParams = {
...activeModelParams,
[name]: Number(e[0]),
}

updateModelParameter(threadId, updatedModelParams)
updateModelParameter(threadId, name, e[0])
}

return (
Expand All @@ -51,9 +39,6 @@ const SliderRightPanel: React.FC<Props> = ({
<div className="flex items-center gap-x-4">
<div className="relative w-full">
<Slider
{...register(name, {
setValueAs: (v: string) => parseInt(v),
})}
value={[value]}
onValueChange={onValueChanged}
min={min}
Expand Down
35 changes: 16 additions & 19 deletions web/helpers/atoms/Thread.atom.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import {
ModelRuntimeParams,
ModelSettingParams,
Thread,
ThreadContent,
ThreadState,
Expand Down Expand Up @@ -110,30 +111,26 @@ export const activeThreadAtom = atom<Thread | undefined>((get) =>
/**
* Store model params at thread level settings
*/
export const threadModelRuntimeParamsAtom = atom<
Record<string, ModelRuntimeParams>
>({})
export const threadModelParamsAtom = atom<Record<string, ModelParams>>({})

export const getActiveThreadModelRuntimeParamsAtom = atom<
ModelRuntimeParams | undefined
>((get) => {
const threadId = get(activeThreadIdAtom)
if (!threadId) {
console.debug('Active thread id is undefined')
return undefined
}
export type ModelParams = ModelRuntimeParams | ModelSettingParams

return get(threadModelRuntimeParamsAtom)[threadId]
})
export const getActiveThreadModelParamsAtom = atom<ModelParams | undefined>(
(get) => {
const threadId = get(activeThreadIdAtom)
if (!threadId) {
console.debug('Active thread id is undefined')
return undefined
}

export const getThreadModelRuntimeParamsAtom = atom(
(get, threadId: string) => get(threadModelRuntimeParamsAtom)[threadId]
return get(threadModelParamsAtom)[threadId]
}
)

export const setThreadModelRuntimeParamsAtom = atom(
export const setThreadModelParamsAtom = atom(
null,
(get, set, threadId: string, params: ModelRuntimeParams) => {
const currentState = { ...get(threadModelRuntimeParamsAtom) }
(get, set, threadId: string, params: ModelParams) => {
const currentState = { ...get(threadModelParamsAtom) }
currentState[threadId] = params
console.debug(
`Update model params for thread ${threadId}, ${JSON.stringify(
Expand All @@ -142,6 +139,6 @@ export const setThreadModelRuntimeParamsAtom = atom(
2
)}`
)
set(threadModelRuntimeParamsAtom, currentState)
set(threadModelParamsAtom, currentState)
}
)
11 changes: 1 addition & 10 deletions web/hooks/useCreateNewThread.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import {
setActiveThreadIdAtom,
threadStatesAtom,
updateThreadAtom,
setThreadModelRuntimeParamsAtom,
} from '@/helpers/atoms/Thread.atom'

const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
Expand All @@ -45,10 +44,6 @@ export const useCreateNewThread = () => {
const createNewThread = useSetAtom(createNewThreadAtom)
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
const updateThread = useSetAtom(updateThreadAtom)
const setThreadModelRuntimeParams = useSetAtom(
setThreadModelRuntimeParamsAtom
)

const { deleteThread } = useDeleteThread()

const requestCreateNewThread = async (
Expand Down Expand Up @@ -77,10 +72,7 @@ export const useCreateNewThread = () => {
model: {
id: modelId,
settings: {},
parameters: {
stream: true,
max_tokens: 1024,
},
parameters: {},
engine: undefined,
},
instructions: assistant.instructions,
Expand All @@ -94,7 +86,6 @@ export const useCreateNewThread = () => {
created: createdAt,
updated: createdAt,
}
setThreadModelRuntimeParams(thread.id, assistantInfo.model.parameters)

// add the new thread on top of the thread list to the state
createNewThread(thread)
Expand Down
5 changes: 2 additions & 3 deletions web/hooks/useRecommendedModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ export default function useRecommendedModel() {
const getRecommendedModel = useCallback(async (): Promise<
Model | undefined
> => {
const models = await getAndSortDownloadedModels()
if (!activeThread) {
return
}

const finishInit = threadStates[activeThread.id].isFinishInit ?? true
if (finishInit) {
const modelId = activeThread.assistants[0]?.model.id
const models = await getAndSortDownloadedModels()
const model = models.find((model) => model.id === modelId)

if (model) {
Expand All @@ -60,7 +60,6 @@ export default function useRecommendedModel() {
} else {
const modelId = activeThread.assistants[0]?.model.id
if (modelId !== '*') {
const models = await getAndSortDownloadedModels()
const model = models.find((model) => model.id === modelId)

if (model) {
Expand All @@ -78,7 +77,7 @@ export default function useRecommendedModel() {
}

// sort the model, for display purpose
const models = await getAndSortDownloadedModels()

if (models.length === 0) {
// if we have no downloaded models, then can't recommend anything
console.debug("No downloaded models, can't recommend anything")
Expand Down
Loading

0 comments on commit c580b4c

Please sign in to comment.