Skip to content

Commit

Permalink
chore: default context length to 2048 (janhq#2746)
Browse files Browse the repository at this point in the history
  • Loading branch information
namchuai authored Apr 17, 2024
1 parent a2cb135 commit 9563278
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 34 deletions.
1 change: 1 addition & 0 deletions core/src/browser/extensions/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,5 @@ export abstract class ModelExtension extends BaseExtension implements ModelInter
abstract importModels(models: ImportingModel[], optionType: OptionType): Promise<void>
abstract updateModelInfo(modelInfo: Partial<Model>): Promise<Model>
abstract fetchHuggingFaceRepoData(repoId: string): Promise<HuggingFaceRepoData>
abstract getDefaultModel(): Promise<Model>
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"min": 0,
"max": 4096,
"step": 128,
"value": 4096
"value": 2048
}
}
]
2 changes: 1 addition & 1 deletion extensions/model-extension/resources/default-model.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"created": 0,
"description": "User self import model",
"settings": {
"ctx_len": 4096,
"ctx_len": 2048,
"embedding": false,
"prompt_template": "{system_message}\n### Instruction: {prompt}\n### Response:",
"llama_model_path": "N/A"
Expand Down
2 changes: 1 addition & 1 deletion extensions/model-extension/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ export default class JanModelExtension extends ModelExtension {
return model
}

private async getDefaultModel(): Promise<Model> {
override async getDefaultModel(): Promise<Model> {
const defaultModel = DEFAULT_MODEL as Model
return defaultModel
}
Expand Down
2 changes: 2 additions & 0 deletions web/helpers/atoms/Model.atom.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ export const removeDownloadedModelAtom = atom(

export const configuredModelsAtom = atom<Model[]>([])

export const defaultModelAtom = atom<Model | undefined>(undefined)

/// TODO: move this part to another atom
// store the paths of the models that are being imported
export const importingModelsAtom = atom<ImportingModel[]>([])
Expand Down
23 changes: 20 additions & 3 deletions web/hooks/useModels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,37 @@ import { useSetAtom } from 'jotai'
import { extensionManager } from '@/extension'
import {
configuredModelsAtom,
defaultModelAtom,
downloadedModelsAtom,
} from '@/helpers/atoms/Model.atom'

const useModels = () => {
const setDownloadedModels = useSetAtom(downloadedModelsAtom)
const setConfiguredModels = useSetAtom(configuredModelsAtom)
const setDefaultModel = useSetAtom(defaultModelAtom)

const getData = useCallback(() => {
const getDownloadedModels = async () => {
const models = await getLocalDownloadedModels()
setDownloadedModels(models)
}

const getConfiguredModels = async () => {
const models = await getLocalConfiguredModels()
setConfiguredModels(models)
}
getDownloadedModels()
getConfiguredModels()
}, [setDownloadedModels, setConfiguredModels])

const getDefaultModel = async () => {
const defaultModel = await getLocalDefaultModel()
setDefaultModel(defaultModel)
}

Promise.all([
getDownloadedModels(),
getConfiguredModels(),
getDefaultModel(),
])
}, [setDownloadedModels, setConfiguredModels, setDefaultModel])

useEffect(() => {
// Try get data on mount
Expand All @@ -46,6 +58,11 @@ const useModels = () => {
}, [getData])
}

const getLocalDefaultModel = async (): Promise<Model | undefined> =>
extensionManager
.get<ModelExtension>(ExtensionTypeEnum.Model)
?.getDefaultModel()

const getLocalConfiguredModels = async (): Promise<Model[]> =>
extensionManager
.get<ModelExtension>(ExtensionTypeEnum.Model)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import { useCallback, useMemo } from 'react'
import {
DownloadState,
HuggingFaceRepoData,
InferenceEngine,
Model,
Quantization,
} from '@janhq/core'
Expand All @@ -23,7 +22,10 @@ import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
import { assistantsAtom } from '@/helpers/atoms/Assistant.atom'

import { importHuggingFaceModelStageAtom } from '@/helpers/atoms/HuggingFace.atom'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
import {
defaultModelAtom,
downloadedModelsAtom,
} from '@/helpers/atoms/Model.atom'

type Props = {
index: number
Expand Down Expand Up @@ -52,15 +54,15 @@ const ModelDownloadRow: React.FC<Props> = ({
const isDownloaded = downloadedModels.find((md) => md.id === fileName) != null

const setHfImportingStage = useSetAtom(importHuggingFaceModelStageAtom)
const defaultModel = useAtomValue(defaultModelAtom)

const model = useMemo(() => {
const promptData: string =
(repoData.cardData['prompt_template'] as string) ??
'{system_message}\n### Instruction: {prompt}\n### Response:'
if (!defaultModel) {
return undefined
}

const model: Model = {
object: 'model',
version: '1.0',
format: 'gguf',
...defaultModel,
sources: [
{
url: downloadUrl,
Expand All @@ -70,38 +72,26 @@ const ModelDownloadRow: React.FC<Props> = ({
id: fileName,
name: fileName,
created: Date.now(),
description: 'User self import model',
settings: {
ctx_len: 4096,
embedding: false,
prompt_template: promptData,
llama_model_path: 'N/A',
},
parameters: {
temperature: 0.7,
top_p: 0.95,
stream: true,
max_tokens: 2048,
stop: ['<endofstring>'],
frequency_penalty: 0.7,
presence_penalty: 0,
},
metadata: {
author: 'User',
tags: repoData.tags,
size: fileSize,
},
engine: InferenceEngine.nitro,
}
console.log('NamH model: ', JSON.stringify(model))
return model
}, [fileName, fileSize, repoData, downloadUrl])
}, [fileName, fileSize, repoData, downloadUrl, defaultModel])

const onAbortDownloadClick = useCallback(() => {
abortModelDownload(model)
if (model) {
abortModelDownload(model)
}
}, [model, abortModelDownload])

const onDownloadClick = useCallback(async () => {
downloadModel(model)
if (model) {
downloadModel(model)
}
}, [model, downloadModel])

const onUseModelClick = useCallback(async () => {
Expand All @@ -120,6 +110,10 @@ const ModelDownloadRow: React.FC<Props> = ({
setHfImportingStage,
])

if (!model) {
return null
}

return (
<div className="flex w-[662px] flex-row items-center justify-between space-x-1 rounded border border-border p-3">
<div className="flex">
Expand Down

0 comments on commit 9563278

Please sign in to comment.