Skip to content

Commit

Permalink
feat: preserve model settings (janhq#3427)
Browse files Browse the repository at this point in the history
* feat: preserve model settings

* feat: preserve model settings across new threads

* chore: lint fix

* fix: feature toggle off should also affect default value retrieve
  • Loading branch information
louis-jan authored Aug 21, 2024
1 parent c8474c8 commit ad9a4a0
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 13 deletions.
3 changes: 3 additions & 0 deletions core/src/types/model/modelEntity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ export type ModelMetadata = {
tags: string[]
size: number
cover?: string
// These settings to preserve model settings across threads
default_ctx_len?: number
default_max_tokens?: number
}

/**
Expand Down
2 changes: 1 addition & 1 deletion extensions/model-extension/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ export default class JanModelExtension extends ModelExtension {
}

/**
* Saves a machine learning model.
* Saves a model file.
* @param model - The model to save.
* @returns A Promise that resolves when the model is saved.
*/
Expand Down
15 changes: 14 additions & 1 deletion web/containers/ModelDropdown/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import {

import { extensionManager } from '@/extension'

import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom'
import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom'
import {
configuredModelsAtom,
Expand Down Expand Up @@ -89,6 +90,7 @@ const ModelDropdown = ({
const featuredModel = configuredModels.filter((x) =>
x.metadata.tags.includes('Featured')
)
const preserveModelSettings = useAtomValue(preserveModelSettingsAtom)

useClickOutside(() => !filterOptionsOpen && setOpen(false), null, [
dropdownOptions,
Expand Down Expand Up @@ -161,14 +163,25 @@ const ModelDropdown = ({
if (activeThread) {
// Default setting ctx_len for the model for a better onboarding experience
// TODO: When Cortex support hardware instructions, we should remove this
const defaultContextLength = preserveModelSettings
? model?.metadata?.default_ctx_len
: 2048
const defaultMaxTokens = preserveModelSettings
? model?.metadata?.default_max_tokens
: 2048
const overriddenSettings =
model?.settings.ctx_len && model.settings.ctx_len > 2048
? { ctx_len: 2048 }
? { ctx_len: defaultContextLength }
: {}
const overriddenParameters =
model?.parameters.max_tokens && model.parameters.max_tokens
? { max_tokens: defaultMaxTokens }
: {}

const modelParams = {
...model?.parameters,
...model?.settings,
...overriddenParameters,
...overriddenSettings,
}

Expand Down
7 changes: 7 additions & 0 deletions web/helpers/atoms/AppConfig.atom.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ const VULKAN_ENABLED = 'vulkanEnabled'
const IGNORE_SSL = 'ignoreSSLFeature'
const HTTPS_PROXY_FEATURE = 'httpsProxyFeature'
const QUICK_ASK_ENABLED = 'quickAskEnabled'
const PRESERVE_MODEL_SETTINGS = 'preserveModelSettings'

export const janDataFolderPathAtom = atom('')

Expand All @@ -23,3 +24,9 @@ export const vulkanEnabledAtom = atomWithStorage(VULKAN_ENABLED, false)
export const quickAskEnabledAtom = atomWithStorage(QUICK_ASK_ENABLED, false)

export const hostAtom = atom('http://localhost:1337/')

// This feature is to allow user to cache model settings on thread creation
export const preserveModelSettingsAtom = atomWithStorage(
PRESERVE_MODEL_SETTINGS,
false
)
11 changes: 11 additions & 0 deletions web/helpers/atoms/Model.atom.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ export const removeDownloadingModelAtom = atom(

export const downloadedModelsAtom = atom<Model[]>([])

export const updateDownloadedModelAtom = atom(
null,
(get, set, updatedModel: Model) => {
const models: Model[] = get(downloadedModelsAtom).map((c) =>
c.id === updatedModel.id ? updatedModel : c
)

set(downloadedModelsAtom, models)
}
)

export const removeDownloadedModelAtom = atom(
null,
(get, set, modelId: string) => {
Expand Down
19 changes: 14 additions & 5 deletions web/hooks/useCreateNewThread.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {
Model,
AssistantTool,
} from '@janhq/core'
import { atom, useAtomValue, useSetAtom } from 'jotai'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'

import { copyOverInstructionEnabledAtom } from '@/containers/CopyInstruction'
import { fileUploadAtom } from '@/containers/Providers/Jotai'
Expand All @@ -24,7 +24,10 @@ import useSetActiveThread from './useSetActiveThread'

import { extensionManager } from '@/extension'

import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
import {
experimentalFeatureEnabledAtom,
preserveModelSettingsAtom,
} from '@/helpers/atoms/AppConfig.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import {
threadsAtom,
Expand Down Expand Up @@ -62,6 +65,7 @@ export const useCreateNewThread = () => {
const copyOverInstructionEnabled = useAtomValue(
copyOverInstructionEnabledAtom
)
const preserveModelSettings = useAtomValue(preserveModelSettingsAtom)
const activeThread = useAtomValue(activeThreadAtom)

const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom)
Expand Down Expand Up @@ -99,15 +103,20 @@ export const useCreateNewThread = () => {
enabled: true,
settings: assistant.tools && assistant.tools[0].settings,
}

const defaultContextLength = preserveModelSettings
? model?.metadata?.default_ctx_len
: 2048
const defaultMaxTokens = preserveModelSettings
? model?.metadata?.default_max_tokens
: 2048
const overriddenSettings =
defaultModel?.settings.ctx_len && defaultModel.settings.ctx_len > 2048
? { ctx_len: 2048 }
? { ctx_len: defaultContextLength }
: {}

const overriddenParameters =
defaultModel?.parameters.max_tokens && defaultModel.parameters.max_tokens
? { max_tokens: 2048 }
? { max_tokens: defaultMaxTokens }
: {}

const createdAt = Date.now()
Expand Down
46 changes: 40 additions & 6 deletions web/hooks/useUpdateModelParameters.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,22 @@ import {
ConversationalExtension,
ExtensionTypeEnum,
InferenceEngine,
Model,
ModelExtension,
Thread,
ThreadAssistantInfo,
} from '@janhq/core'

import { useAtomValue, useSetAtom } from 'jotai'
import { useAtom, useAtomValue, useSetAtom } from 'jotai'

import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'

import { extensionManager } from '@/extension'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom'
import {
selectedModelAtom,
updateDownloadedModelAtom,
} from '@/helpers/atoms/Model.atom'
import {
ModelParams,
getActiveThreadModelParamsAtom,
Expand All @@ -28,8 +34,10 @@ export type UpdateModelParameter = {

export default function useUpdateModelParameters() {
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
const selectedModel = useAtomValue(selectedModelAtom)
const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
const updateDownloadedModel = useSetAtom(updateDownloadedModelAtom)
const preserveModelFeatureEnabled = useAtomValue(preserveModelSettingsAtom)

const updateModelParameter = useCallback(
async (thread: Thread, settings: UpdateModelParameter) => {
Expand All @@ -40,12 +48,11 @@ export default function useUpdateModelParameters() {

// update the state
setThreadModelParams(thread.id, updatedModelParams)
const runtimeParams = toRuntimeParams(updatedModelParams)
const settingParams = toSettingParams(updatedModelParams)

const assistants = thread.assistants.map(
(assistant: ThreadAssistantInfo) => {
const runtimeParams = toRuntimeParams(updatedModelParams)
const settingParams = toSettingParams(updatedModelParams)

assistant.model.parameters = runtimeParams
assistant.model.settings = settingParams
if (selectedModel) {
Expand All @@ -65,6 +72,33 @@ export default function useUpdateModelParameters() {
await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.saveThread(updatedThread)

// Persists default settings to model file
// Do not overwrite ctx_len and max_tokens
if (preserveModelFeatureEnabled && selectedModel) {
const updatedModel = {
...selectedModel,
parameters: {
...runtimeParams,
max_tokens: selectedModel.parameters.max_tokens,
},
settings: {
...settingParams,
ctx_len: selectedModel.settings.ctx_len,
},
metadata: {
...selectedModel.metadata,
default_ctx_len: settingParams.ctx_len,
default_max_tokens: runtimeParams.max_tokens,
},
} as Model

await extensionManager
.get<ModelExtension>(ExtensionTypeEnum.Model)
?.saveModel(updatedModel)
setSelectedModel(updatedModel)
updateDownloadedModel(updatedModel)
}
},
[activeModelParams, selectedModel, setThreadModelParams]
)
Expand Down
25 changes: 25 additions & 0 deletions web/screens/Settings/Advanced/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import {
proxyEnabledAtom,
vulkanEnabledAtom,
quickAskEnabledAtom,
preserveModelSettingsAtom,
} from '@/helpers/atoms/AppConfig.atom'

type GPU = {
Expand Down Expand Up @@ -64,6 +65,9 @@ const Advanced = () => {
const [proxyEnabled, setProxyEnabled] = useAtom(proxyEnabledAtom)
const quickAskEnabled = useAtomValue(quickAskEnabledAtom)

const [preserveModelSettings, setPreserveModelSettings] = useAtom(
preserveModelSettingsAtom
)
const [proxy, setProxy] = useAtom(proxyAtom)
const [ignoreSSL, setIgnoreSSL] = useAtom(ignoreSslAtom)

Expand Down Expand Up @@ -385,6 +389,27 @@ const Advanced = () => {
</div>
)}

{experimentalEnabled && (
<div className="flex w-full flex-col items-start justify-between gap-4 border-b border-[hsla(var(--app-border))] py-4 first:pt-0 last:border-none sm:flex-row">
<div className="flex-shrink-0 space-y-1">
<div className="flex gap-x-2">
<h6 className="font-semibold capitalize">
Preserve Model Settings
</h6>
</div>
<p className="font-medium leading-relaxed text-[hsla(var(--text-secondary))]">
Save model settings changes directly to the model file so that
new threads will reuse the previous settings.
</p>
</div>

<Switch
checked={preserveModelSettings}
onChange={(e) => setPreserveModelSettings(e.target.checked)}
/>
</div>
)}

<DataFolder />

{/* Proxy */}
Expand Down

0 comments on commit ad9a4a0

Please sign in to comment.