Skip to content

Commit

Permalink
fix(InferenceExtension): janhq#1067 sync the nitro process state (jan…
Browse files Browse the repository at this point in the history
…hq#1493)

Signed-off-by: James <[email protected]>
Co-authored-by: James <[email protected]>
  • Loading branch information
namchuai and James authored Jan 10, 2024
1 parent 31fdd89 commit 9183330
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 61 deletions.
126 changes: 75 additions & 51 deletions extensions/inference-nitro-extension/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ import {
fs,
Model,
joinPath,
InferenceExtension,
} from "@janhq/core";
import { InferenceExtension } from "@janhq/core";
import { requestInference } from "./helpers/sse";
import { ulid } from "ulid";
import { join } from "path";
Expand All @@ -36,9 +36,14 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
private static readonly _settingsDir = "file://settings";
private static readonly _engineMetadataFileName = "nitro.json";

private static _currentModel: Model;
/**
* Interval, in milliseconds, between Nitro process health checks (every 5 seconds).
*/
private static readonly _intervalHealthCheck = 5 * 1000;

private _currentModel: Model;

private static _engineSettings: EngineSettings = {
private _engineSettings: EngineSettings = {
ctx_len: 2048,
ngl: 100,
cpu_threads: 1,
Expand All @@ -48,6 +53,18 @@ export default class JanInferenceNitroExtension implements InferenceExtension {

controller = new AbortController();
isCancelled = false;

/**
* The interval id for the health check. Used to stop the health check.
*/
private getNitroProcesHealthIntervalId: NodeJS.Timeout | undefined =
undefined;

/**
 * Last known state of the nitro process, as returned by the
 * "getCurrentNitroProcessInfo" call on the main process.
 * Remains `undefined` until a health check has stored a result.
 */
private nitroProcessInfo: { isRunning: boolean } | undefined = undefined;

/**
* Returns the type of the extension.
* @returns {ExtensionType} The type of the extension.
Expand All @@ -71,21 +88,13 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
this.writeDefaultEngineSettings();

// Events subscription
events.on(EventName.OnMessageSent, (data) =>
JanInferenceNitroExtension.handleMessageRequest(data, this)
);
events.on(EventName.OnMessageSent, (data) => this.onMessageRequest(data));

events.on(EventName.OnModelInit, (model: Model) => {
JanInferenceNitroExtension.handleModelInit(model);
});
events.on(EventName.OnModelInit, (model: Model) => this.onModelInit(model));

events.on(EventName.OnModelStop, (model: Model) => {
JanInferenceNitroExtension.handleModelStop(model);
});
events.on(EventName.OnModelStop, (model: Model) => this.onModelStop(model));

events.on(EventName.OnInferenceStopped, () => {
JanInferenceNitroExtension.handleInferenceStopped(this);
});
events.on(EventName.OnInferenceStopped, () => this.onInferenceStopped());

// Attempt to fetch nvidia info
await executeOnMain(MODULE, "updateNvidiaInfo", {});
Expand All @@ -104,23 +113,22 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
);
if (await fs.existsSync(engineFile)) {
const engine = await fs.readFileSync(engineFile, "utf-8");
JanInferenceNitroExtension._engineSettings =
this._engineSettings =
typeof engine === "object" ? engine : JSON.parse(engine);
} else {
await fs.writeFileSync(
engineFile,
JSON.stringify(JanInferenceNitroExtension._engineSettings, null, 2)
JSON.stringify(this._engineSettings, null, 2)
);
}
} catch (err) {
console.error(err);
}
}

private static async handleModelInit(model: Model) {
if (model.engine !== "nitro") {
return;
}
private async onModelInit(model: Model) {
if (model.engine !== "nitro") return;

const modelFullPath = await joinPath(["models", model.id]);

const nitroInitResult = await executeOnMain(MODULE, "initModel", {
Expand All @@ -130,26 +138,49 @@ export default class JanInferenceNitroExtension implements InferenceExtension {

if (nitroInitResult.error === null) {
events.emit(EventName.OnModelFail, model);
} else {
JanInferenceNitroExtension._currentModel = model;
events.emit(EventName.OnModelReady, model);
return;
}

this._currentModel = model;
events.emit(EventName.OnModelReady, model);

this.getNitroProcesHealthIntervalId = setInterval(
() => this.periodicallyGetNitroHealth(),
JanInferenceNitroExtension._intervalHealthCheck
);
}

private static async handleModelStop(model: Model) {
if (model.engine !== "nitro") {
return;
} else {
await executeOnMain(MODULE, "stopModel");
events.emit(EventName.OnModelStopped, model);
/**
 * Stops the currently running model when it is served by the nitro engine.
 *
 * The periodic health check is cancelled before the stop request is sent, so
 * the timer is released even if the stop request rejects, and so a poll
 * scheduled during shutdown cannot emit an extra `OnModelStopped` event.
 *
 * @param model - The model that was asked to stop; ignored unless its engine is "nitro".
 */
private async onModelStop(model: Model) {
  if (model.engine !== "nitro") return;

  // Stop the periodic health check first.
  if (this.getNitroProcesHealthIntervalId) {
    console.debug("Stop calling Nitro process health check");
    clearInterval(this.getNitroProcesHealthIntervalId);
    this.getNitroProcesHealthIntervalId = undefined;
  }

  await executeOnMain(MODULE, "stopModel");
  events.emit(EventName.OnModelStopped, {});
}

/**
 * Asks the main process for the current nitro process state and emits
 * `OnModelStopped` when the process goes from running to not running.
 *
 * Failures are logged here rather than rethrown, so a caller that does not
 * await this method (for example a timer callback) never produces an
 * unhandled promise rejection.
 */
private async periodicallyGetNitroHealth(): Promise<void> {
  try {
    const health = await executeOnMain(MODULE, "getCurrentNitroProcessInfo");

    const wasRunning = this.nitroProcessInfo?.isRunning ?? false;
    // Optional chaining: tolerate a missing result instead of throwing.
    if (wasRunning && health?.isRunning === false) {
      console.debug("Nitro process is stopped");
      events.emit(EventName.OnModelStopped, {});
    }
    this.nitroProcessInfo = health;
  } catch (err) {
    console.error("Failed to get Nitro process health", err);
  }
}

private static async handleInferenceStopped(
instance: JanInferenceNitroExtension
) {
instance.isCancelled = true;
instance.controller?.abort();
/**
 * Flags the in-flight inference as cancelled and aborts it through the
 * current AbortController, if there is one.
 */
private async onInferenceStopped(): Promise<void> {
  const { controller } = this;
  this.isCancelled = true;
  controller?.abort();
}

/**
Expand All @@ -171,10 +202,7 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
};

return new Promise(async (resolve, reject) => {
requestInference(
data.messages ?? [],
JanInferenceNitroExtension._currentModel
).subscribe({
requestInference(data.messages ?? [], this._currentModel).subscribe({
next: (_content) => {},
complete: async () => {
resolve(message);
Expand All @@ -192,13 +220,9 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
* Pass instance as a reference.
* @param {MessageRequest} data - The data for the new message request.
*/
private static async handleMessageRequest(
data: MessageRequest,
instance: JanInferenceNitroExtension
) {
if (data.model.engine !== "nitro") {
return;
}
private async onMessageRequest(data: MessageRequest) {
if (data.model.engine !== "nitro") return;

const timestamp = Date.now();
const message: ThreadMessage = {
id: ulid(),
Expand All @@ -213,13 +237,13 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
};
events.emit(EventName.OnMessageResponse, message);

instance.isCancelled = false;
instance.controller = new AbortController();
this.isCancelled = false;
this.controller = new AbortController();

requestInference(
data.messages ?? [],
{ ...JanInferenceNitroExtension._currentModel, ...data.model },
instance.controller
{ ...this._currentModel, ...data.model },
this.controller
).subscribe({
next: (content) => {
const messageContent: ThreadContent = {
Expand All @@ -239,7 +263,7 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
events.emit(EventName.OnMessageUpdate, message);
},
error: async (err) => {
if (instance.isCancelled || message.content.length) {
if (this.isCancelled || message.content.length) {
message.status = MessageStatus.Stopped;
events.emit(EventName.OnMessageUpdate, message);
return;
Expand Down
18 changes: 14 additions & 4 deletions extensions/inference-nitro-extension/src/module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ let subprocess = undefined;
let currentModelFile: string = undefined;
let currentSettings = undefined;

// Most recent Nitro process state snapshot; written by getCurrentNitroProcessInfo.
let nitroProcessInfo = undefined;

/**
* Stops a Nitro subprocess.
* @param wrapper - The model wrapper.
Expand Down Expand Up @@ -80,7 +82,7 @@ async function updateNvidiaDriverInfo(): Promise<void> {
);
}

function checkFileExistenceInPaths(file: string, paths: string[]): boolean {
/**
 * Reports whether `file` exists inside at least one of the given directories.
 *
 * @param file - File name (or relative path) to look for.
 * @param paths - Directories to search, checked in order.
 * @returns `true` as soon as one joined path exists, otherwise `false`.
 */
function isExists(file: string, paths: string[]): boolean {
  for (const dir of paths) {
    if (existsSync(path.join(dir, file))) {
      return true;
    }
  }
  return false;
}

Expand All @@ -104,12 +106,12 @@ function updateCudaExistence() {
}

let cudaExists = filesCuda12.every(
(file) => existsSync(file) || checkFileExistenceInPaths(file, paths)
(file) => existsSync(file) || isExists(file, paths)
);

if (!cudaExists) {
cudaExists = filesCuda11.every(
(file) => existsSync(file) || checkFileExistenceInPaths(file, paths)
(file) => existsSync(file) || isExists(file, paths)
);
if (cudaExists) {
cudaVersion = "11";
Expand Down Expand Up @@ -461,7 +463,7 @@ function spawnNitroProcess(nitroResourceProbe: any): Promise<any> {
function getResourcesInfo(): Promise<ResourcesInfo> {
return new Promise(async (resolve) => {
const cpu = await osUtils.cpuCount();
console.log("cpu: ", cpu);
console.debug("cpu: ", cpu);
const response: ResourcesInfo = {
numCpuPhysicalCore: cpu,
memAvailable: 0,
Expand All @@ -470,6 +472,13 @@ function getResourcesInfo(): Promise<ResourcesInfo> {
});
}

/**
 * Takes a snapshot of the Nitro subprocess state and caches it in
 * `nitroProcessInfo`.
 *
 * Declared `async` so the returned value really is the Promise that the
 * signature advertises.
 *
 * @returns A promise resolving to `{ isRunning }`, where `isRunning` is true
 * when `subprocess` is neither `null` nor `undefined`.
 */
const getCurrentNitroProcessInfo = async (): Promise<{
  isRunning: boolean;
}> => {
  const info = { isRunning: subprocess != null };
  nitroProcessInfo = info;
  return info;
};

function dispose() {
// clean other registered resources here
killSubprocess();
Expand All @@ -481,4 +490,5 @@ module.exports = {
killSubprocess,
dispose,
updateNvidiaInfo,
getCurrentNitroProcessInfo,
};
8 changes: 2 additions & 6 deletions web/containers/Providers/EventHandler.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import {
ExtensionType,
MessageStatus,
Model,
ConversationalExtension,
} from '@janhq/core'
import { ConversationalExtension } from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai'

import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
Expand Down Expand Up @@ -64,14 +64,10 @@ export default function EventHandler({ children }: { children: ReactNode }) {
}))
}

async function handleModelStopped(model: Model) {
/**
 * Clears the active model and puts the model state back to 'start'
 * once a model has stopped. The reset is deferred by 500 ms.
 */
async function handleModelStopped() {
  const resetModelState = async () => {
    setActiveModel(undefined)
    setStateModel({ state: 'start', loading: false, model: '' })
  }

  setTimeout(resetModelState, 500)
}

Expand Down

0 comments on commit 9183330

Please sign in to comment.