diff --git a/.gitignore b/.gitignore index 646e6842a7..eaee28a62a 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,4 @@ core/test_results.html coverage .yarn .yarnrc +*.tsbuildinfo diff --git a/core/src/browser/core.test.ts b/core/src/browser/core.test.ts index 84250888ec..f38cc0b404 100644 --- a/core/src/browser/core.test.ts +++ b/core/src/browser/core.test.ts @@ -1,98 +1,109 @@ -import { openExternalUrl } from './core'; -import { joinPath } from './core'; -import { openFileExplorer } from './core'; -import { getJanDataFolderPath } from './core'; -import { abortDownload } from './core'; -import { getFileSize } from './core'; -import { executeOnMain } from './core'; +import { openExternalUrl } from './core' +import { joinPath } from './core' +import { openFileExplorer } from './core' +import { getJanDataFolderPath } from './core' +import { abortDownload } from './core' +import { getFileSize } from './core' +import { executeOnMain } from './core' -it('should open external url', async () => { - const url = 'http://example.com'; - globalThis.core = { - api: { - openExternalUrl: jest.fn().mockResolvedValue('opened') +describe('test core apis', () => { + it('should open external url', async () => { + const url = 'http://example.com' + globalThis.core = { + api: { + openExternalUrl: jest.fn().mockResolvedValue('opened'), + }, } - }; - const result = await openExternalUrl(url); - expect(globalThis.core.api.openExternalUrl).toHaveBeenCalledWith(url); - expect(result).toBe('opened'); -}); + const result = await openExternalUrl(url) + expect(globalThis.core.api.openExternalUrl).toHaveBeenCalledWith(url) + expect(result).toBe('opened') + }) - -it('should join paths', async () => { - const paths = ['/path/one', '/path/two']; - globalThis.core = { - api: { - joinPath: jest.fn().mockResolvedValue('/path/one/path/two') + it('should join paths', async () => { + const paths = ['/path/one', '/path/two'] + globalThis.core = { + api: { + joinPath: jest.fn().mockResolvedValue('/path/one/path/two'), + }, } - }; - const result = await joinPath(paths); - expect(globalThis.core.api.joinPath).toHaveBeenCalledWith(paths); - expect(result).toBe('/path/one/path/two'); -}); - + const result = await joinPath(paths) + expect(globalThis.core.api.joinPath).toHaveBeenCalledWith(paths) + expect(result).toBe('/path/one/path/two') + }) -it('should open file explorer', async () => { - const path = '/path/to/open'; - globalThis.core = { - api: { - openFileExplorer: jest.fn().mockResolvedValue('opened') + it('should open file explorer', async () => { + const path = '/path/to/open' + globalThis.core = { + api: { + openFileExplorer: jest.fn().mockResolvedValue('opened'), + }, } - }; - const result = await openFileExplorer(path); - expect(globalThis.core.api.openFileExplorer).toHaveBeenCalledWith(path); - expect(result).toBe('opened'); -}); - + const result = await openFileExplorer(path) + expect(globalThis.core.api.openFileExplorer).toHaveBeenCalledWith(path) + expect(result).toBe('opened') + }) -it('should get jan data folder path', async () => { - globalThis.core = { - api: { - getJanDataFolderPath: jest.fn().mockResolvedValue('/path/to/jan/data') + it('should get jan data folder path', async () => { + globalThis.core = { + api: { + getJanDataFolderPath: jest.fn().mockResolvedValue('/path/to/jan/data'), + }, } - }; - const result = await getJanDataFolderPath(); - expect(globalThis.core.api.getJanDataFolderPath).toHaveBeenCalled(); - expect(result).toBe('/path/to/jan/data'); -}); + const result = await getJanDataFolderPath() + expect(globalThis.core.api.getJanDataFolderPath).toHaveBeenCalled() + expect(result).toBe('/path/to/jan/data') + }) - -it('should abort download', async () => { - const fileName = 'testFile'; - globalThis.core = { - api: { - abortDownload: jest.fn().mockResolvedValue('aborted') + it('should abort download', async () => { + const fileName = 'testFile' + globalThis.core = { + api: { + abortDownload: jest.fn().mockResolvedValue('aborted'), + }, } - }; - const result = await abortDownload(fileName); - expect(globalThis.core.api.abortDownload).toHaveBeenCalledWith(fileName); - expect(result).toBe('aborted'); -}); - + const result = await abortDownload(fileName) + expect(globalThis.core.api.abortDownload).toHaveBeenCalledWith(fileName) + expect(result).toBe('aborted') + }) -it('should get file size', async () => { - const url = 'http://example.com/file'; - globalThis.core = { - api: { - getFileSize: jest.fn().mockResolvedValue(1024) + it('should get file size', async () => { + const url = 'http://example.com/file' + globalThis.core = { + api: { + getFileSize: jest.fn().mockResolvedValue(1024), + }, } - }; - const result = await getFileSize(url); - expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url); - expect(result).toBe(1024); -}); + const result = await getFileSize(url) + expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url) + expect(result).toBe(1024) + }) + it('should execute function on main process', async () => { + const extension = 'testExtension' + const method = 'testMethod' + const args = ['arg1', 'arg2'] + globalThis.core = { + api: { + invokeExtensionFunc: jest.fn().mockResolvedValue('result'), + }, + } + const result = await executeOnMain(extension, method, ...args) + expect(globalThis.core.api.invokeExtensionFunc).toHaveBeenCalledWith(extension, method, ...args) + expect(result).toBe('result') + }) +}) -it('should execute function on main process', async () => { - const extension = 'testExtension'; - const method = 'testMethod'; - const args = ['arg1', 'arg2']; - globalThis.core = { - api: { - invokeExtensionFunc: jest.fn().mockResolvedValue('result') +describe('dirName - just a pass thru api', () => { + it('should retrieve the directory name from a file path', async () => { + const mockDirName = jest.fn() + globalThis.core = { + api: { + dirName: mockDirName.mockResolvedValue('/path/to'), + }, } - }; - const result = await executeOnMain(extension, method, ...args); - expect(globalThis.core.api.invokeExtensionFunc).toHaveBeenCalledWith(extension, method, ...args); - expect(result).toBe('result'); -}); + // Normal file path with extension + const path = '/path/to/file.txt' + await globalThis.core.api.dirName(path) + expect(mockDirName).toHaveBeenCalledWith(path) + }) +}) diff --git a/core/src/browser/core.ts b/core/src/browser/core.ts index fdbceb06bb..b19e0b339f 100644 --- a/core/src/browser/core.ts +++ b/core/src/browser/core.ts @@ -68,6 +68,13 @@ const openFileExplorer: (path: string) => Promise = (path) => const joinPath: (paths: string[]) => Promise = (paths) => globalThis.core.api?.joinPath(paths) +/** + * Get dirname of a file path. + * @param path - The file path to retrieve dirname. + * @returns {Promise} A promise that resolves the dirname. + */ +const dirName: (path: string) => Promise = (path) => globalThis.core.api?.dirName(path) + /** * Retrieve the basename from an url. * @param path - The path to retrieve. @@ -161,5 +168,6 @@ export { systemInformation, showToast, getFileSize, + dirName, FileStat, } diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts index 7cd9f513e2..75354de88e 100644 --- a/core/src/browser/extensions/engines/AIEngine.ts +++ b/core/src/browser/extensions/engines/AIEngine.ts @@ -2,7 +2,7 @@ import { getJanDataFolderPath, joinPath } from '../../core' import { events } from '../../events' import { BaseExtension } from '../../extension' import { fs } from '../../fs' -import { MessageRequest, Model, ModelEvent } from '../../../types' +import { MessageRequest, Model, ModelEvent, ModelFile } from '../../../types' import { EngineManager } from './EngineManager' /** @@ -21,7 +21,7 @@ export abstract class AIEngine extends BaseExtension { override onLoad() { this.registerEngine() - events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model)) + events.on(ModelEvent.OnModelInit, (model: ModelFile) => this.loadModel(model)) events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model)) } @@ -78,7 +78,7 @@ export abstract class AIEngine extends BaseExtension { /** * Loads the model. */ - async loadModel(model: Model): Promise { + async loadModel(model: ModelFile): Promise { if (model.engine.toString() !== this.provider) return Promise.resolve() events.emit(ModelEvent.OnModelReady, model) return Promise.resolve() diff --git a/core/src/browser/extensions/engines/LocalOAIEngine.ts b/core/src/browser/extensions/engines/LocalOAIEngine.ts index fb9e4962c4..123b9a5930 100644 --- a/core/src/browser/extensions/engines/LocalOAIEngine.ts +++ b/core/src/browser/extensions/engines/LocalOAIEngine.ts @@ -1,6 +1,6 @@ -import { executeOnMain, getJanDataFolderPath, joinPath, systemInformation } from '../../core' +import { executeOnMain, systemInformation, dirName } from '../../core' import { events } from '../../events' -import { Model, ModelEvent } from '../../../types' +import { Model, ModelEvent, ModelFile } from '../../../types' import { OAIEngine } from './OAIEngine' /** @@ -14,22 +14,24 @@ export abstract class LocalOAIEngine extends OAIEngine { unloadModelFunctionName: string = 'unloadModel' /** - * On extension load, subscribe to events. + * This class represents a base for local inference providers in the OpenAI architecture. + * It extends the OAIEngine class and provides the implementation of loading and unloading models locally. + * The loadModel function subscribes to the ModelEvent.OnModelInit event, loading models when initiated. + * The unloadModel function subscribes to the ModelEvent.OnModelStop event, unloading models when stopped. */ override onLoad() { super.onLoad() // These events are applicable to local inference providers - events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model)) + events.on(ModelEvent.OnModelInit, (model: ModelFile) => this.loadModel(model)) events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model)) } /** * Load the model. */ - override async loadModel(model: Model): Promise { + override async loadModel(model: ModelFile): Promise { if (model.engine.toString() !== this.provider) return - const modelFolderName = 'models' - const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id]) + const modelFolder = await dirName(model.file_path) const systemInfo = await systemInformation() const res = await executeOnMain( this.nodeModule, diff --git a/core/src/browser/extensions/model.ts b/core/src/browser/extensions/model.ts index 5b3089403f..040542927d 100644 --- a/core/src/browser/extensions/model.ts +++ b/core/src/browser/extensions/model.ts @@ -4,6 +4,7 @@ import { HuggingFaceRepoData, ImportingModel, Model, + ModelFile, ModelInterface, OptionType, } from '../../types' @@ -25,12 +26,11 @@ export abstract class ModelExtension extends BaseExtension implements ModelInter network?: { proxy: string; ignoreSSL?: boolean } ): Promise abstract cancelModelDownload(modelId: string): Promise - abstract deleteModel(modelId: string): Promise - abstract saveModel(model: Model): Promise - abstract getDownloadedModels(): Promise - abstract getConfiguredModels(): Promise + abstract deleteModel(model: ModelFile): Promise + abstract getDownloadedModels(): Promise + abstract getConfiguredModels(): Promise abstract importModels(models: ImportingModel[], optionType: OptionType): Promise - abstract updateModelInfo(modelInfo: Partial): Promise + abstract updateModelInfo(modelInfo: Partial): Promise abstract fetchHuggingFaceRepoData(repoId: string): Promise abstract getDefaultModel(): Promise } diff --git a/core/src/node/api/processors/app.test.ts b/core/src/node/api/processors/app.test.ts index 3ada5df1e5..5c4daef29d 100644 --- a/core/src/node/api/processors/app.test.ts +++ b/core/src/node/api/processors/app.test.ts @@ -1,40 +1,57 @@ -import { App } from './app'; +jest.mock('../../helper', () => ({ + ...jest.requireActual('../../helper'), + getJanDataFolderPath: () => './app', +})) +import { dirname } from 'path' +import { App } from './app' it('should call stopServer', () => { - const app = new App(); - const stopServerMock = jest.fn().mockResolvedValue('Server stopped'); + const app = new App() + const stopServerMock = jest.fn().mockResolvedValue('Server stopped') jest.mock('@janhq/server', () => ({ - stopServer: stopServerMock - })); - const result = app.stopServer(); - expect(stopServerMock).toHaveBeenCalled(); -}); + stopServer: stopServerMock, + })) + app.stopServer() + expect(stopServerMock).toHaveBeenCalled() +}) it('should correctly retrieve basename', () => { - const app = new App(); - const result = app.baseName('/path/to/file.txt'); - expect(result).toBe('file.txt'); -}); + const app = new App() + const result = app.baseName('/path/to/file.txt') + expect(result).toBe('file.txt') +}) it('should correctly identify subdirectories', () => { - const app = new App(); - const basePath = process.platform === 'win32' ? 'C:\\path\\to' : '/path/to'; - const subPath = process.platform === 'win32' ? 'C:\\path\\to\\subdir' : '/path/to/subdir'; - const result = app.isSubdirectory(basePath, subPath); - expect(result).toBe(true); -}); + const app = new App() + const basePath = process.platform === 'win32' ? 'C:\\path\\to' : '/path/to' + const subPath = process.platform === 'win32' ? 'C:\\path\\to\\subdir' : '/path/to/subdir' + const result = app.isSubdirectory(basePath, subPath) + expect(result).toBe(true) +}) it('should correctly join multiple paths', () => { - const app = new App(); - const result = app.joinPath(['path', 'to', 'file']); - const expectedPath = process.platform === 'win32' ? 'path\\to\\file' : 'path/to/file'; - expect(result).toBe(expectedPath); -}); + const app = new App() + const result = app.joinPath(['path', 'to', 'file']) + const expectedPath = process.platform === 'win32' ? 'path\\to\\file' : 'path/to/file' + expect(result).toBe(expectedPath) +}) it('should call correct function with provided arguments using process method', () => { - const app = new App(); - const mockFunc = jest.fn(); - app.joinPath = mockFunc; - app.process('joinPath', ['path1', 'path2']); - expect(mockFunc).toHaveBeenCalledWith(['path1', 'path2']); -}); + const app = new App() + const mockFunc = jest.fn() + app.joinPath = mockFunc + app.process('joinPath', ['path1', 'path2']) + expect(mockFunc).toHaveBeenCalledWith(['path1', 'path2']) +}) + +it('should retrieve the directory name from a file path (Unix/Windows)', async () => { + const app = new App() + const path = 'C:/Users/John Doe/Desktop/file.txt' + expect(await app.dirName(path)).toBe('C:/Users/John Doe/Desktop') +}) + +it('should retrieve the directory name when using file protocol', async () => { + const app = new App() + const path = 'file:/models/file.txt' + expect(await app.dirName(path)).toBe(process.platform === 'win32' ? 'app\\models' : 'app/models') +}) diff --git a/core/src/node/api/processors/app.ts b/core/src/node/api/processors/app.ts index 15460ba565..a0808c5ac6 100644 --- a/core/src/node/api/processors/app.ts +++ b/core/src/node/api/processors/app.ts @@ -1,4 +1,4 @@ -import { basename, isAbsolute, join, relative } from 'path' +import { basename, dirname, isAbsolute, join, relative } from 'path' import { Processor } from './Processor' import { @@ -6,6 +6,8 @@ import { appResourcePath, getAppConfigurations as appConfiguration, updateAppConfiguration, + normalizeFilePath, + getJanDataFolderPath, } from '../../helper' export class App implements Processor { @@ -28,6 +30,18 @@ export class App implements Processor { return join(...args) } + /** + * Get dirname of a file path. + * @param path - The file path to retrieve dirname. + */ + dirName(path: string) { + const arg = + path.startsWith(`file:/`) || path.startsWith(`file:\\`) + ? join(getJanDataFolderPath(), normalizeFilePath(path)) + : path + return dirname(arg) + } + /** * Checks if the given path is a subdirectory of the given directory. * diff --git a/core/src/types/api/index.ts b/core/src/types/api/index.ts index bca11c0a89..8f1ff70bf9 100644 --- a/core/src/types/api/index.ts +++ b/core/src/types/api/index.ts @@ -37,6 +37,7 @@ export enum AppRoute { getAppConfigurations = 'getAppConfigurations', updateAppConfiguration = 'updateAppConfiguration', joinPath = 'joinPath', + dirName = 'dirName', isSubdirectory = 'isSubdirectory', baseName = 'baseName', startServer = 'startServer', diff --git a/core/src/types/file/index.ts b/core/src/types/file/index.ts index 1b36a5777d..4db956b1e7 100644 --- a/core/src/types/file/index.ts +++ b/core/src/types/file/index.ts @@ -52,3 +52,18 @@ type DownloadSize = { total: number transferred: number } + +/** + * The file metadata + */ +export type FileMetadata = { + /** + * The origin file path. + */ + file_path: string + + /** + * The file name. + */ + file_name: string +} diff --git a/core/src/types/model/modelEntity.ts b/core/src/types/model/modelEntity.ts index f154f7f04b..933c698c39 100644 --- a/core/src/types/model/modelEntity.ts +++ b/core/src/types/model/modelEntity.ts @@ -1,3 +1,5 @@ +import { FileMetadata } from '../file' + /** * Represents the information about a model. * @stored @@ -151,3 +153,8 @@ export type ModelRuntimeParams = { export type ModelInitFailed = Model & { error: Error } + +/** + * ModelFile is the model.json entity and it's file metadata + */ +export type ModelFile = Model & FileMetadata diff --git a/core/src/types/model/modelInterface.ts b/core/src/types/model/modelInterface.ts index 639c7c8d34..5b5856231e 100644 --- a/core/src/types/model/modelInterface.ts +++ b/core/src/types/model/modelInterface.ts @@ -1,5 +1,5 @@ import { GpuSetting } from '../miscellaneous' -import { Model } from './modelEntity' +import { Model, ModelFile } from './modelEntity' /** * Model extension for managing models. @@ -29,14 +29,7 @@ export interface ModelInterface { * @param modelId - The ID of the model to delete. * @returns A Promise that resolves when the model has been deleted. */ - deleteModel(modelId: string): Promise - - /** - * Saves a model. - * @param model - The model to save. - * @returns A Promise that resolves when the model has been saved. - */ - saveModel(model: Model): Promise + deleteModel(model: ModelFile): Promise /** * Gets a list of downloaded models. diff --git a/extensions/inference-nitro-extension/src/index.ts b/extensions/inference-nitro-extension/src/index.ts index d79e076d4e..6e825e8fd1 100644 --- a/extensions/inference-nitro-extension/src/index.ts +++ b/extensions/inference-nitro-extension/src/index.ts @@ -22,6 +22,7 @@ import { downloadFile, DownloadState, DownloadEvent, + ModelFile, } from '@janhq/core' declare const CUDA_DOWNLOAD_URL: string @@ -94,7 +95,7 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine { this.nitroProcessInfo = health } - override loadModel(model: Model): Promise { + override loadModel(model: ModelFile): Promise { if (model.engine !== this.provider) return Promise.resolve() this.getNitroProcessHealthIntervalId = setInterval( () => this.periodicallyGetNitroHealth(), diff --git a/extensions/inference-nitro-extension/src/node/index.ts b/extensions/inference-nitro-extension/src/node/index.ts index 3a969ad5e7..98ca4572fa 100644 --- a/extensions/inference-nitro-extension/src/node/index.ts +++ b/extensions/inference-nitro-extension/src/node/index.ts @@ -6,12 +6,12 @@ import fetchRT from 'fetch-retry' import { log, getSystemResourceInfo, - Model, InferenceEngine, ModelSettingParams, PromptTemplate, SystemInformation, getJanDataFolderPath, + ModelFile, } from '@janhq/core/node' import { executableNitroFile } from './execute' import terminate from 'terminate' @@ -25,7 +25,7 @@ const fetchRetry = fetchRT(fetch) */ interface ModelInitOptions { modelFolder: string - model: Model + model: ModelFile } // The PORT to use for the Nitro subprocess const PORT = 3928 diff --git a/extensions/model-extension/jest.config.js b/extensions/model-extension/jest.config.js new file mode 100644 index 0000000000..3e32adceb2 --- /dev/null +++ b/extensions/model-extension/jest.config.js @@ -0,0 +1,9 @@ +/** @type {import('ts-jest').JestConfigWithTsJest} */ +module.exports = { + preset: 'ts-jest', + testEnvironment: 'node', + transform: { + 'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest', + }, + transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'], +} diff --git a/extensions/model-extension/package.json b/extensions/model-extension/package.json index 4a2c61b716..9a406dcf42 100644 --- a/extensions/model-extension/package.json +++ b/extensions/model-extension/package.json @@ -8,6 +8,7 @@ "author": "Jan ", "license": "AGPL-3.0", "scripts": { + "test": "jest", "build": "tsc --module commonjs && rollup -c rollup.config.ts --configPlugin @rollup/plugin-typescript --bundleConfigAsCjs", "build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install" }, diff --git a/extensions/model-extension/rollup.config.ts b/extensions/model-extension/rollup.config.ts index c3f3acc77a..d36d8ffacd 100644 --- a/extensions/model-extension/rollup.config.ts +++ b/extensions/model-extension/rollup.config.ts @@ -27,7 +27,7 @@ export default [ // Allow json resolution json(), // Compile TypeScript files - typescript({ useTsconfigDeclarationDir: true }), + typescript({ useTsconfigDeclarationDir: true, exclude: ['**/__tests__', '**/*.test.ts'], }), // Compile TypeScript files // Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs) // commonjs(), @@ -62,7 +62,7 @@ export default [ // Allow json resolution json(), // Compile TypeScript files - typescript({ useTsconfigDeclarationDir: true }), + typescript({ useTsconfigDeclarationDir: true, exclude: ['**/__tests__', '**/*.test.ts'], }), // Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs) commonjs(), // Allow node_modules resolution, so you can use 'external' to control diff --git a/extensions/model-extension/src/index.test.ts b/extensions/model-extension/src/index.test.ts new file mode 100644 index 0000000000..6816d7101a --- /dev/null +++ b/extensions/model-extension/src/index.test.ts @@ -0,0 +1,564 @@ +const readDirSyncMock = jest.fn() +const existMock = jest.fn() +const readFileSyncMock = jest.fn() + +jest.mock('@janhq/core', () => ({ + ...jest.requireActual('@janhq/core/node'), + fs: { + existsSync: existMock, + readdirSync: readDirSyncMock, + readFileSync: readFileSyncMock, + fileStat: () => ({ + isDirectory: false, + }), + }, + dirName: jest.fn(), + joinPath: (paths) => paths.join('/'), + ModelExtension: jest.fn(), +})) + +import JanModelExtension from '.' +import { fs, dirName } from '@janhq/core' + +describe('JanModelExtension', () => { + let sut: JanModelExtension + + beforeAll(() => { + // @ts-ignore + sut = new JanModelExtension() + }) + + afterEach(() => { + jest.clearAllMocks() + }) + + describe('getConfiguredModels', () => { + describe("when there's no models are pre-populated", () => { + it('should return empty array', async () => { + // Mock configured models data + const configuredModels = [] + existMock.mockReturnValue(true) + readDirSyncMock.mockReturnValue([]) + + const result = await sut.getConfiguredModels() + expect(result).toEqual([]) + }) + }) + + describe("when there's are pre-populated models - all flattened", () => { + it('returns configured models data - flatten folder - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2'] + else return ['model.json'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getConfiguredModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + expect.objectContaining({ + file_path: 'file://models/model2/model.json', + id: '2', + }), + ]) + ) + }) + }) + + describe("when there's are pre-populated models - there are nested folders", () => { + it('returns configured models data - flatten folder - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2/model2-1'] + else return ['model.json'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else if (path.includes('model2/model2-1')) + return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getConfiguredModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + expect.objectContaining({ + file_path: 'file://models/model2/model2-1/model.json', + id: '2', + }), + ]) + ) + }) + }) + }) + + describe('getDownloadedModels', () => { + describe('no models downloaded', () => { + it('should return empty array', async () => { + // Mock downloaded models data + const downloadedModels = [] + existMock.mockReturnValue(true) + readDirSyncMock.mockReturnValue([]) + + const result = await sut.getDownloadedModels() + expect(result).toEqual([]) + }) + }) + describe('only one model is downloaded', () => { + describe('flatten folder', () => { + it('returns downloaded models - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2'] + else if (path === 'file://models/model1') + return ['model.json', 'test.gguf'] + else return ['model.json'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getDownloadedModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + ]) + ) + }) + }) + }) + + describe('all models are downloaded', () => { + describe('nested folders', () => { + it('returns downloaded models - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2/model2-1'] + else return ['model.json', 'test.gguf'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getDownloadedModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + expect.objectContaining({ + file_path: 'file://models/model2/model2-1/model.json', + id: '2', + }), + ]) + ) + }) + }) + }) + + describe('all models are downloaded with uppercased GGUF files', () => { + it('returns downloaded models - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2/model2-1'] + else if (path === 'file://models/model1') + return ['model.json', 'test.GGUF'] + else return ['model.json', 'test.gguf'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getDownloadedModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + expect.objectContaining({ + file_path: 'file://models/model2/model2-1/model.json', + id: '2', + }), + ]) + ) + }) + }) + + describe('all models are downloaded - GGUF & Tensort RT', () => { + it('returns downloaded models - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2/model2-1'] + else if (path === 'file://models/model1') + return ['model.json', 'test.gguf'] + else return ['model.json', 'test.engine'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getDownloadedModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + expect.objectContaining({ + file_path: 'file://models/model2/model2-1/model.json', + id: '2', + }), + ]) + ) + }) + }) + }) + + describe('deleteModel', () => { + describe('model is a GGUF model', () => { + it('should delete the GGUF file', async () => { + fs.unlinkSync = jest.fn() + const dirMock = dirName as jest.Mock + dirMock.mockReturnValue('file://models/model1') + + fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({})) + + readDirSyncMock.mockImplementation((path) => { + return ['model.json', 'test.gguf'] + }) + + existMock.mockReturnValue(true) + + await sut.deleteModel({ + file_path: 'file://models/model1/model.json', + } as any) + + expect(fs.unlinkSync).toHaveBeenCalledWith( + 'file://models/model1/test.gguf' + ) + }) + + it('no gguf file presented', async () => { + fs.unlinkSync = jest.fn() + const dirMock = dirName as jest.Mock + dirMock.mockReturnValue('file://models/model1') + + fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({})) + + readDirSyncMock.mockReturnValue(['model.json']) + + existMock.mockReturnValue(true) + + await sut.deleteModel({ + file_path: 'file://models/model1/model.json', + } as any) + + expect(fs.unlinkSync).toHaveBeenCalledTimes(0) + }) + + it('delete an imported model', async () => { + fs.rm = jest.fn() + const dirMock = dirName as jest.Mock + dirMock.mockReturnValue('file://models/model1') + + readDirSyncMock.mockReturnValue(['model.json', 'test.gguf']) + + // MARK: This is a tricky logic implement? + // I will just add test for now but will align on the legacy implementation + fs.readFileSync = jest.fn().mockReturnValue( + JSON.stringify({ + metadata: { + author: 'user', + }, + }) + ) + + existMock.mockReturnValue(true) + + await sut.deleteModel({ + file_path: 'file://models/model1/model.json', + } as any) + + expect(fs.rm).toHaveBeenCalledWith('file://models/model1') + }) + + it('delete tensorrt-models', async () => { + fs.rm = jest.fn() + const dirMock = dirName as jest.Mock + dirMock.mockReturnValue('file://models/model1') + + readDirSyncMock.mockReturnValue(['model.json', 'test.engine']) + + fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({})) + + existMock.mockReturnValue(true) + + await sut.deleteModel({ + file_path: 'file://models/model1/model.json', + } as any) + + expect(fs.unlinkSync).toHaveBeenCalledWith('file://models/model1/test.engine') + }) + }) + }) +}) diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts index e2f68a58c7..ac9b06a095 100644 --- a/extensions/model-extension/src/index.ts +++ b/extensions/model-extension/src/index.ts @@ -22,6 +22,8 @@ import { getFileSize, AllQuantizations, ModelEvent, + ModelFile, + dirName, } from '@janhq/core' import { extractFileName } from './helpers/path' @@ -48,16 +50,7 @@ export default class JanModelExtension extends ModelExtension { ] private static readonly _tensorRtEngineFormat = '.engine' private static readonly _supportedGpuArch = ['ampere', 'ada'] - private static readonly _safetensorsRegexs = [ - /model\.safetensors$/, - /model-[0-9]+-of-[0-9]+\.safetensors$/, - ] - private static readonly _pytorchRegexs = [ - /pytorch_model\.bin$/, - /consolidated\.[0-9]+\.pth$/, - /pytorch_model-[0-9]+-of-[0-9]+\.bin$/, - /.*\.pt$/, - ] + interrupted = false /** @@ -319,9 +312,9 @@ export default class JanModelExtension extends ModelExtension { * @param filePath - The path to the model file to delete. * @returns A Promise that resolves when the model is deleted. */ - async deleteModel(modelId: string): Promise { + async deleteModel(model: ModelFile): Promise { try { - const dirPath = await joinPath([JanModelExtension._homeDir, modelId]) + const dirPath = await dirName(model.file_path) const jsonFilePath = await joinPath([ dirPath, JanModelExtension._modelMetadataFileName, @@ -330,9 +323,11 @@ export default class JanModelExtension extends ModelExtension { await this.readModelMetadata(jsonFilePath) ) as Model + // TODO: This is so tricky? + // Should depend on sources? const isUserImportModel = modelInfo.metadata?.author?.toLowerCase() === 'user' - if (isUserImportModel) { + if (isUserImportModel) { // just delete the folder return fs.rm(dirPath) } @@ -350,30 +345,11 @@ export default class JanModelExtension extends ModelExtension { } } - /** - * Saves a model file. - * @param model - The model to save. - * @returns A Promise that resolves when the model is saved. - */ - async saveModel(model: Model): Promise { - const jsonFilePath = await joinPath([ - JanModelExtension._homeDir, - model.id, - JanModelExtension._modelMetadataFileName, - ]) - - try { - await fs.writeFileSync(jsonFilePath, JSON.stringify(model, null, 2)) - } catch (err) { - console.error(err) - } - } - /** * Gets all downloaded models. * @returns A Promise that resolves with an array of all models. */ - async getDownloadedModels(): Promise { + async getDownloadedModels(): Promise { return await this.getModelsMetadata( async (modelDir: string, model: Model) => { if (!JanModelExtension._offlineInferenceEngine.includes(model.engine)) @@ -425,8 +401,10 @@ export default class JanModelExtension extends ModelExtension { ): Promise { // try to find model.json recursively inside each folder if (!(await fs.existsSync(folderFullPath))) return undefined + const files: string[] = await fs.readdirSync(folderFullPath) if (files.length === 0) return undefined + if (files.includes(JanModelExtension._modelMetadataFileName)) { return joinPath([ folderFullPath, @@ -446,7 +424,7 @@ export default class JanModelExtension extends ModelExtension { private async getModelsMetadata( selector?: (path: string, model: Model) => Promise - ): Promise { + ): Promise { try { if (!(await fs.existsSync(JanModelExtension._homeDir))) { console.debug('Model folder not found') @@ -469,6 +447,7 @@ export default class JanModelExtension extends ModelExtension { JanModelExtension._homeDir, dirName, ]) + const jsonPath = await this.getModelJsonPath(folderFullPath) if (await fs.existsSync(jsonPath)) { @@ -486,6 +465,8 @@ export default class JanModelExtension extends ModelExtension { }, ] } + model.file_path = jsonPath + model.file_name = JanModelExtension._modelMetadataFileName if (selector && !(await selector?.(dirName, model))) { return @@ -506,7 +487,7 @@ export default class JanModelExtension extends ModelExtension { typeof result.value === 'object' ? result.value : JSON.parse(result.value) - return model as Model + return model as ModelFile } catch { console.debug(`Unable to parse model metadata: ${result.value}`) } @@ -637,7 +618,7 @@ export default class JanModelExtension extends ModelExtension { * Gets all available models. * @returns A Promise that resolves with an array of all models. */ - async getConfiguredModels(): Promise { + async getConfiguredModels(): Promise { return this.getModelsMetadata() } @@ -669,7 +650,7 @@ export default class JanModelExtension extends ModelExtension { modelBinaryPath: string, modelFolderName: string, modelFolderPath: string - ): Promise { + ): Promise { const fileStats = await fs.fileStat(modelBinaryPath, true) const binaryFileSize = fileStats.size @@ -732,25 +713,21 @@ export default class JanModelExtension extends ModelExtension { await fs.writeFileSync(modelFilePath, JSON.stringify(model, null, 2)) - return model + return { + ...model, + file_path: modelFilePath, + file_name: JanModelExtension._modelMetadataFileName, + } } - async updateModelInfo(modelInfo: Partial): Promise { - const modelId = modelInfo.id + async updateModelInfo(modelInfo: Partial): Promise { if (modelInfo.id == null) throw new Error('Model ID is required') - const janDataFolderPath = await getJanDataFolderPath() - const jsonFilePath = await joinPath([ - janDataFolderPath, - 'models', - modelId, - JanModelExtension._modelMetadataFileName, - ]) const model = JSON.parse( - await this.readModelMetadata(jsonFilePath) - ) as Model + await this.readModelMetadata(modelInfo.file_path) + ) as ModelFile - const updatedModel: Model = { + const updatedModel: ModelFile = { ...model, ...modelInfo, parameters: { @@ -765,9 +742,15 @@ export default class JanModelExtension extends ModelExtension { ...model.metadata, ...modelInfo.metadata, }, + // Should not persist file_path & file_name + file_path: undefined, + file_name: undefined, } - await fs.writeFileSync(jsonFilePath, JSON.stringify(updatedModel, null, 2)) + await fs.writeFileSync( + modelInfo.file_path, + JSON.stringify(updatedModel, null, 2) + ) return updatedModel } diff --git a/extensions/model-extension/tsconfig.json b/extensions/model-extension/tsconfig.json index addd8e1274..0d32529346 100644 --- a/extensions/model-extension/tsconfig.json +++ b/extensions/model-extension/tsconfig.json @@ -10,5 +10,6 @@ "skipLibCheck": true, "rootDir": "./src" }, - "include": ["./src"] + "include": ["./src"], + "exclude": ["**/*.test.ts"] } diff --git a/extensions/tensorrt-llm-extension/src/index.ts b/extensions/tensorrt-llm-extension/src/index.ts index 189abc706a..7f68c43bd9 100644 --- a/extensions/tensorrt-llm-extension/src/index.ts +++ b/extensions/tensorrt-llm-extension/src/index.ts @@ -23,6 +23,7 @@ import { ModelEvent, getJanDataFolderPath, SystemInformation, + ModelFile, } from '@janhq/core' /** @@ -137,7 +138,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine { events.emit(ModelEvent.OnModelsUpdate, {}) } - override async loadModel(model: Model): Promise { + override async loadModel(model: ModelFile): Promise { if ((await this.installationState()) === 'Installed') return super.loadModel(model) diff --git a/web/containers/ModelDropdown/index.tsx b/web/containers/ModelDropdown/index.tsx index 92d8addd08..d8743ddce9 100644 --- a/web/containers/ModelDropdown/index.tsx +++ b/web/containers/ModelDropdown/index.tsx @@ -46,7 +46,6 @@ import { import { extensionManager } from '@/extension' -import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom' import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom' import { configuredModelsAtom, @@ -91,8 +90,6 @@ const ModelDropdown = ({ const featuredModel = configuredModels.filter((x) => x.metadata.tags.includes('Featured') ) - const preserveModelSettings = useAtomValue(preserveModelSettingsAtom) - const { updateThreadMetadata } = useCreateNewThread() useClickOutside(() => !filterOptionsOpen && setOpen(false), null, [ @@ -191,27 +188,14 @@ const ModelDropdown = ({ ], }) - // Default setting ctx_len for the model for a better onboarding experience - // TODO: When Cortex support hardware instructions, we should remove this - const defaultContextLength = preserveModelSettings - ? model?.metadata?.default_ctx_len - : 2048 - const defaultMaxTokens = preserveModelSettings - ? model?.metadata?.default_max_tokens - : 2048 const overriddenSettings = - model?.settings.ctx_len && model.settings.ctx_len > 2048 - ? { ctx_len: defaultContextLength ?? 2048 } - : {} - const overriddenParameters = - model?.parameters.max_tokens && model.parameters.max_tokens - ? { max_tokens: defaultMaxTokens ?? 2048 } + model?.settings.ctx_len && model.settings.ctx_len > 4096 + ? { ctx_len: 4096 } : {} const modelParams = { ...model?.parameters, ...model?.settings, - ...overriddenParameters, ...overriddenSettings, } @@ -222,6 +206,7 @@ const ModelDropdown = ({ if (model) updateModelParameter(activeThread, { params: modelParams, + modelPath: model.file_path, modelId: model.id, engine: model.engine, }) @@ -235,7 +220,6 @@ const ModelDropdown = ({ setThreadModelParams, updateModelParameter, updateThreadMetadata, - preserveModelSettings, ] ) diff --git a/web/helpers/atoms/AppConfig.atom.ts b/web/helpers/atoms/AppConfig.atom.ts index e7b7efaecd..f4acc7dc22 100644 --- a/web/helpers/atoms/AppConfig.atom.ts +++ b/web/helpers/atoms/AppConfig.atom.ts @@ -7,7 +7,6 @@ const VULKAN_ENABLED = 'vulkanEnabled' const IGNORE_SSL = 'ignoreSSLFeature' const HTTPS_PROXY_FEATURE = 'httpsProxyFeature' const QUICK_ASK_ENABLED = 'quickAskEnabled' -const PRESERVE_MODEL_SETTINGS = 'preserveModelSettings' export const janDataFolderPathAtom = atom('') @@ -24,9 +23,3 @@ export const vulkanEnabledAtom = atomWithStorage(VULKAN_ENABLED, false) export const quickAskEnabledAtom = atomWithStorage(QUICK_ASK_ENABLED, false) export const hostAtom = atom('http://localhost:1337/') - -// This feature is to allow user to cache model settings on thread creation -export const preserveModelSettingsAtom = atomWithStorage( - PRESERVE_MODEL_SETTINGS, - false -) diff --git a/web/helpers/atoms/Model.atom.ts b/web/helpers/atoms/Model.atom.ts index 77b1bfa4e4..d2d0ca9f41 100644 --- a/web/helpers/atoms/Model.atom.ts +++ b/web/helpers/atoms/Model.atom.ts @@ -1,4 +1,4 @@ -import { ImportingModel, Model, InferenceEngine } from '@janhq/core' +import { ImportingModel, Model, InferenceEngine, ModelFile } from '@janhq/core' import { atom } from 'jotai' import { localEngines } from '@/utils/modelEngine' @@ -32,18 +32,7 @@ export const removeDownloadingModelAtom = atom( } ) -export const downloadedModelsAtom = atom([]) - -export const updateDownloadedModelAtom = atom( - null, - (get, set, updatedModel: Model) => { - const models: Model[] = get(downloadedModelsAtom).map((c) => - c.id === updatedModel.id ? updatedModel : c - ) - - set(downloadedModelsAtom, models) - } -) +export const downloadedModelsAtom = atom([]) export const removeDownloadedModelAtom = atom( null, @@ -57,7 +46,7 @@ export const removeDownloadedModelAtom = atom( } ) -export const configuredModelsAtom = atom([]) +export const configuredModelsAtom = atom([]) export const defaultModelAtom = atom(undefined) @@ -144,6 +133,6 @@ export const updateImportingModelAtom = atom( } ) -export const selectedModelAtom = atom(undefined) +export const selectedModelAtom = atom(undefined) export const showEngineListModelAtom = atom(localEngines) diff --git a/web/hooks/useActiveModel.ts b/web/hooks/useActiveModel.ts index 9768ac4c4a..2d53678c31 100644 --- a/web/hooks/useActiveModel.ts +++ b/web/hooks/useActiveModel.ts @@ -1,6 +1,6 @@ import { useCallback, useEffect, useRef } from 'react' -import { EngineManager, Model } from '@janhq/core' +import { EngineManager, Model, ModelFile } from '@janhq/core' import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' import { toaster } from '@/containers/Toast' @@ -11,7 +11,7 @@ import { vulkanEnabledAtom } from '@/helpers/atoms/AppConfig.atom' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' -export const activeModelAtom = atom(undefined) +export const activeModelAtom = atom(undefined) export const loadModelErrorAtom = atom(undefined) type ModelState = { @@ -37,7 +37,7 @@ export function useActiveModel() { const [pendingModelLoad, setPendingModelLoad] = useAtom(pendingModelLoadAtom) const isVulkanEnabled = useAtomValue(vulkanEnabledAtom) - const downloadedModelsRef = useRef([]) + const downloadedModelsRef = useRef([]) useEffect(() => { downloadedModelsRef.current = downloadedModels diff --git a/web/hooks/useCreateNewThread.ts b/web/hooks/useCreateNewThread.ts index 80acfa3cca..5548259fd8 100644 --- a/web/hooks/useCreateNewThread.ts +++ b/web/hooks/useCreateNewThread.ts @@ -7,8 +7,8 @@ import { Thread, ThreadAssistantInfo, ThreadState, - Model, AssistantTool, + ModelFile, } from '@janhq/core' import { atom, useAtomValue, useSetAtom } from 'jotai' @@ -26,10 +26,7 @@ import useSetActiveThread from './useSetActiveThread' import { extensionManager } from '@/extension' -import { - experimentalFeatureEnabledAtom, - preserveModelSettingsAtom, -} from '@/helpers/atoms/AppConfig.atom' +import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { threadsAtom, @@ -67,7 +64,6 @@ export const useCreateNewThread = () => { const copyOverInstructionEnabled = useAtomValue( copyOverInstructionEnabledAtom ) - const preserveModelSettings = useAtomValue(preserveModelSettingsAtom) const activeThread = useAtomValue(activeThreadAtom) const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom) @@ -80,7 +76,7 @@ export const useCreateNewThread = () => { const requestCreateNewThread = async ( assistant: Assistant, - model?: Model | undefined + model?: ModelFile | undefined ) => { // Stop generating if any setIsGeneratingResponse(false) @@ -109,19 +105,13 @@ export const useCreateNewThread = () => { enabled: true, settings: assistant.tools && assistant.tools[0].settings, } - const defaultContextLength = preserveModelSettings - ? defaultModel?.metadata?.default_ctx_len - : 2048 - const defaultMaxTokens = preserveModelSettings - ? defaultModel?.metadata?.default_max_tokens - : 2048 const overriddenSettings = defaultModel?.settings.ctx_len && defaultModel.settings.ctx_len > 2048 - ? { ctx_len: defaultContextLength ?? 2048 } + ? { ctx_len: 4096 } : {} const overriddenParameters = defaultModel?.parameters.max_tokens - ? { max_tokens: defaultMaxTokens ?? 2048 } + ? { max_tokens: 4096 } : {} const createdAt = Date.now() diff --git a/web/hooks/useDeleteModel.ts b/web/hooks/useDeleteModel.ts index 9736f82563..5a7a319b2e 100644 --- a/web/hooks/useDeleteModel.ts +++ b/web/hooks/useDeleteModel.ts @@ -1,6 +1,6 @@ import { useCallback } from 'react' -import { ExtensionTypeEnum, ModelExtension, Model } from '@janhq/core' +import { ExtensionTypeEnum, ModelExtension, ModelFile } from '@janhq/core' import { useSetAtom } from 'jotai' @@ -13,8 +13,8 @@ export default function useDeleteModel() { const removeDownloadedModel = useSetAtom(removeDownloadedModelAtom) const deleteModel = useCallback( - async (model: Model) => { - await localDeleteModel(model.id) + async (model: ModelFile) => { + await localDeleteModel(model) removeDownloadedModel(model.id) toaster({ title: 'Model Deletion Successful', @@ -28,5 +28,7 @@ export default function useDeleteModel() { return { deleteModel } } -const localDeleteModel = async (id: string) => - extensionManager.get(ExtensionTypeEnum.Model)?.deleteModel(id) +const localDeleteModel = async (model: ModelFile) => + extensionManager + .get(ExtensionTypeEnum.Model) + ?.deleteModel(model) diff --git a/web/hooks/useModels.ts b/web/hooks/useModels.ts index 5a6f13e035..8333c35c35 100644 --- a/web/hooks/useModels.ts +++ b/web/hooks/useModels.ts @@ -5,6 +5,7 @@ import { Model, ModelEvent, ModelExtension, + ModelFile, events, } from '@janhq/core' @@ -63,12 +64,12 @@ const getLocalDefaultModel = async (): Promise => .get(ExtensionTypeEnum.Model) ?.getDefaultModel() -const getLocalConfiguredModels = async (): Promise => +const getLocalConfiguredModels = async (): Promise => extensionManager .get(ExtensionTypeEnum.Model) ?.getConfiguredModels() ?? [] -const getLocalDownloadedModels = async (): Promise => +const getLocalDownloadedModels = async (): Promise => extensionManager .get(ExtensionTypeEnum.Model) ?.getDownloadedModels() ?? [] diff --git a/web/hooks/useRecommendedModel.ts b/web/hooks/useRecommendedModel.ts index 21a9c69e72..ed56efa552 100644 --- a/web/hooks/useRecommendedModel.ts +++ b/web/hooks/useRecommendedModel.ts @@ -1,6 +1,6 @@ import { useCallback, useEffect, useState } from 'react' -import { Model, InferenceEngine } from '@janhq/core' +import { Model, InferenceEngine, ModelFile } from '@janhq/core' import { atom, useAtomValue } from 'jotai' @@ -24,12 +24,16 @@ export const LAST_USED_MODEL_ID = 'last-used-model-id' */ export default function useRecommendedModel() { const activeModel = useAtomValue(activeModelAtom) - const [sortedModels, setSortedModels] = useState([]) - const [recommendedModel, setRecommendedModel] = useState() + const [sortedModels, setSortedModels] = useState([]) + const [recommendedModel, setRecommendedModel] = useState< + ModelFile | undefined + >() const activeThread = useAtomValue(activeThreadAtom) const downloadedModels = useAtomValue(downloadedModelsAtom) - const getAndSortDownloadedModels = useCallback(async (): Promise => { + const getAndSortDownloadedModels = useCallback(async (): Promise< + ModelFile[] + > => { const models = downloadedModels.sort((a, b) => a.engine !== InferenceEngine.nitro && b.engine === InferenceEngine.nitro ? 1 diff --git a/web/hooks/useUpdateModelParameters.ts b/web/hooks/useUpdateModelParameters.ts index 46bf07cd50..af30210adc 100644 --- a/web/hooks/useUpdateModelParameters.ts +++ b/web/hooks/useUpdateModelParameters.ts @@ -4,8 +4,6 @@ import { ConversationalExtension, ExtensionTypeEnum, InferenceEngine, - Model, - ModelExtension, Thread, ThreadAssistantInfo, } from '@janhq/core' @@ -17,14 +15,8 @@ import { extractModelLoadParams, } from '@/utils/modelParam' -import useRecommendedModel from './useRecommendedModel' - import { extensionManager } from '@/extension' -import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom' -import { - selectedModelAtom, - updateDownloadedModelAtom, -} from '@/helpers/atoms/Model.atom' +import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { ModelParams, getActiveThreadModelParamsAtom, @@ -34,16 +26,14 @@ import { export type UpdateModelParameter = { params?: ModelParams modelId?: string + modelPath?: string engine?: InferenceEngine } export default function useUpdateModelParameters() { const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom) - const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom) + const [selectedModel] = useAtom(selectedModelAtom) const setThreadModelParams = useSetAtom(setThreadModelParamsAtom) - const updateDownloadedModel = useSetAtom(updateDownloadedModelAtom) - const preserveModelFeatureEnabled = useAtomValue(preserveModelSettingsAtom) - const { recommendedModel, setRecommendedModel } = useRecommendedModel() const updateModelParameter = useCallback( async (thread: Thread, settings: UpdateModelParameter) => { @@ -83,50 +73,8 @@ export default function useUpdateModelParameters() { await extensionManager .get(ExtensionTypeEnum.Conversational) ?.saveThread(updatedThread) - - // Persists default settings to model file - // Do not overwrite ctx_len and max_tokens - if (preserveModelFeatureEnabled) { - const defaultContextLength = settingParams.ctx_len - const defaultMaxTokens = runtimeParams.max_tokens - - // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-unused-vars - const { ctx_len, ...toSaveSettings } = settingParams - // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-unused-vars - const { max_tokens, ...toSaveParams } = runtimeParams - - const updatedModel = { - id: settings.modelId ?? selectedModel?.id, - parameters: { - ...toSaveSettings, - }, - settings: { - ...toSaveParams, - }, - metadata: { - default_ctx_len: defaultContextLength, - default_max_tokens: defaultMaxTokens, - }, - } as Partial - - const model = await extensionManager - .get(ExtensionTypeEnum.Model) - ?.updateModelInfo(updatedModel) - if (model) updateDownloadedModel(model) - if (selectedModel?.id === model?.id) setSelectedModel(model) - if (recommendedModel?.id === model?.id) setRecommendedModel(model) - } }, - [ - activeModelParams, - selectedModel, - setThreadModelParams, - preserveModelFeatureEnabled, - updateDownloadedModel, - setSelectedModel, - recommendedModel, - setRecommendedModel, - ] + [activeModelParams, selectedModel, setThreadModelParams] ) const processStopWords = (params: ModelParams): ModelParams => { diff --git a/web/screens/Hub/ModelList/ModelHeader/index.tsx b/web/screens/Hub/ModelList/ModelHeader/index.tsx index b20977affe..44a3fd2785 100644 --- a/web/screens/Hub/ModelList/ModelHeader/index.tsx +++ b/web/screens/Hub/ModelList/ModelHeader/index.tsx @@ -1,6 +1,6 @@ import { useCallback } from 'react' -import { Model } from '@janhq/core' +import { ModelFile } from '@janhq/core' import { Button, Badge, Tooltip } from '@janhq/joi' import { useAtomValue, useSetAtom } from 'jotai' @@ -38,7 +38,7 @@ import { } from '@/helpers/atoms/SystemBar.atom' type Props = { - model: Model + model: ModelFile onClick: () => void open: string } diff --git a/web/screens/Hub/ModelList/ModelItem/index.tsx b/web/screens/Hub/ModelList/ModelItem/index.tsx index c9b2f13294..ec9d885a1d 100644 --- a/web/screens/Hub/ModelList/ModelItem/index.tsx +++ b/web/screens/Hub/ModelList/ModelItem/index.tsx @@ -1,6 +1,6 @@ import { useState } from 'react' -import { Model } from '@janhq/core' +import { ModelFile } from '@janhq/core' import { Badge } from '@janhq/joi' import { twMerge } from 'tailwind-merge' @@ -12,7 +12,7 @@ import ModelItemHeader from '@/screens/Hub/ModelList/ModelHeader' import { toGibibytes } from '@/utils/converter' type Props = { - model: Model + model: ModelFile } const ModelItem: React.FC = ({ model }) => { diff --git a/web/screens/Hub/ModelList/index.tsx b/web/screens/Hub/ModelList/index.tsx index aea67b4e3b..8fc30d5418 100644 --- a/web/screens/Hub/ModelList/index.tsx +++ b/web/screens/Hub/ModelList/index.tsx @@ -1,6 +1,6 @@ import { useMemo } from 'react' -import { Model } from '@janhq/core' +import { ModelFile } from '@janhq/core' import { useAtomValue } from 'jotai' @@ -9,16 +9,16 @@ import ModelItem from '@/screens/Hub/ModelList/ModelItem' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' type Props = { - models: Model[] + models: ModelFile[] } const ModelList = ({ models }: Props) => { const downloadedModels = useAtomValue(downloadedModelsAtom) - const sortedModels: Model[] = useMemo(() => { - const featuredModels: Model[] = [] - const remoteModels: Model[] = [] - const localModels: Model[] = [] - const remainingModels: Model[] = [] + const sortedModels: ModelFile[] = useMemo(() => { + const featuredModels: ModelFile[] = [] + const remoteModels: ModelFile[] = [] + const localModels: ModelFile[] = [] + const remainingModels: ModelFile[] = [] models.forEach((m) => { if (m.metadata?.tags?.includes('Featured')) { featuredModels.push(m) diff --git a/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx b/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx index 951a11d59e..c3f09f1713 100644 --- a/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx +++ b/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx @@ -53,7 +53,7 @@ const ModelDownloadRow: React.FC = ({ const { requestCreateNewThread } = useCreateNewThread() const setMainViewState = useSetAtom(mainViewStateAtom) const assistants = useAtomValue(assistantsAtom) - const isDownloaded = downloadedModels.find((md) => md.id === fileName) != null + const downloadedModel = downloadedModels.find((md) => md.id === fileName) const setHfImportingStage = useSetAtom(importHuggingFaceModelStageAtom) const defaultModel = useAtomValue(defaultModelAtom) @@ -100,12 +100,12 @@ const ModelDownloadRow: React.FC = ({ alert('No assistant available') return } - await requestCreateNewThread(assistants[0], model) + await requestCreateNewThread(assistants[0], downloadedModel) setMainViewState(MainViewState.Thread) setHfImportingStage('NONE') }, [ assistants, - model, + downloadedModel, requestCreateNewThread, setMainViewState, setHfImportingStage, @@ -139,7 +139,7 @@ const ModelDownloadRow: React.FC = ({ - {isDownloaded ? ( + {downloadedModel ? (