🦙 feat: Ollama Vision Support (danny-avila#2643)
* refactor: checkVisionRequest, search availableModels for valid vision model instead of using default

* feat: install ollama-js, add typedefs

* feat: Ollama Vision Support

* ci: fix test
danny-avila authored May 9, 2024
1 parent 3c5fa40 commit c94278b
Showing 12 changed files with 388 additions and 115 deletions.
154 changes: 154 additions & 0 deletions api/app/clients/OllamaClient.js
@@ -0,0 +1,154 @@
const { z } = require('zod');
const axios = require('axios');
const { Ollama } = require('ollama');
const { deriveBaseURL } = require('~/utils');
const { logger } = require('~/config');

const ollamaPayloadSchema = z.object({
  mirostat: z.number().optional(),
  mirostat_eta: z.number().optional(),
  mirostat_tau: z.number().optional(),
  num_ctx: z.number().optional(),
  repeat_last_n: z.number().optional(),
  repeat_penalty: z.number().optional(),
  temperature: z.number().optional(),
  seed: z.number().nullable().optional(),
  stop: z.array(z.string()).optional(),
  tfs_z: z.number().optional(),
  num_predict: z.number().optional(),
  top_k: z.number().optional(),
  top_p: z.number().optional(),
  stream: z.optional(z.boolean()),
  model: z.string(),
});
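
// Illustrative note: zod object schemas strip unknown keys by default, so an
// OpenAI-style payload can be passed straight through `.parse` and only the
// Ollama-relevant options survive, e.g.:
//   ollamaPayloadSchema.parse({ model: 'llava', temperature: 0.7, messages: [] })
//   // => { model: 'llava', temperature: 0.7 }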

/**
 * @param {string} imageUrl
 * @returns {string | undefined} The Base64 payload, or undefined if none was found.
 */
const getValidBase64 = (imageUrl) => {
  const parts = imageUrl.split(';base64,');

  if (parts.length === 2) {
    return parts[1];
  } else {
    logger.error('Invalid or no Base64 string found in URL.');
  }
};
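
// Example (illustrative): getValidBase64('data:image/png;base64,iVBORw0KGgo=')
// returns 'iVBORw0KGgo='; a URL without a ';base64,' segment logs an error and
// yields undefined.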

class OllamaClient {
  constructor(options = {}) {
    const host = deriveBaseURL(options.baseURL ?? 'http://localhost:11434');
    /** @type {Ollama} */
    this.client = new Ollama({ host });
  }

  /**
   * Fetches Ollama models from the specified base API path.
   * @param {string} baseURL
   * @returns {Promise<string[]>} The Ollama models.
   */
  static async fetchModels(baseURL) {
    let models = [];
    if (!baseURL) {
      return models;
    }
    try {
      const ollamaEndpoint = deriveBaseURL(baseURL);
      /** @type {Promise<AxiosResponse<OllamaListResponse>>} */
      const response = await axios.get(`${ollamaEndpoint}/api/tags`);
      models = response.data.models.map((tag) => tag.name);
      return models;
    } catch (error) {
      const logMessage =
        'Failed to fetch models from Ollama API. If you are not using Ollama directly, and instead, through some aggregator or reverse proxy that handles fetching via OpenAI spec, ensure the name of the endpoint doesn\'t start with `ollama` (case-insensitive).';
      logger.error(logMessage, error);
      return [];
    }
  }
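
  // For reference, fetchModels above hits Ollama's GET /api/tags, which
  // responds with `{ models: [{ name, modified_at, size, ... }] }`; only
  // `name` is kept, e.g. ['llava:latest', 'mistral:latest'].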

  /**
   * @param {ChatCompletionMessage[]} messages
   * @returns {OllamaMessage[]}
   */
  static formatOpenAIMessages(messages) {
    const ollamaMessages = [];

    for (const message of messages) {
      if (typeof message.content === 'string') {
        ollamaMessages.push({
          role: message.role,
          content: message.content,
        });
        continue;
      }

      let aggregatedText = '';
      let imageUrls = [];

      for (const content of message.content) {
        if (content.type === 'text') {
          aggregatedText += content.text + ' ';
        } else if (content.type === 'image_url') {
          imageUrls.push(getValidBase64(content.image_url.url));
        }
      }

      const ollamaMessage = {
        role: message.role,
        content: aggregatedText.trim(),
      };

      if (imageUrls.length > 0) {
        ollamaMessage.images = imageUrls;
      }

      ollamaMessages.push(ollamaMessage);
    }

    return ollamaMessages;
  }
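
  // Example conversion for formatOpenAIMessages above (illustrative): the
  // OpenAI content-part array
  //   [{ role: 'user', content: [
  //     { type: 'text', text: 'Describe this image' },
  //     { type: 'image_url', image_url: { url: 'data:image/png;base64,iVBORw0KGgo=' } },
  //   ] }]
  // becomes the Ollama shape
  //   [{ role: 'user', content: 'Describe this image', images: ['iVBORw0KGgo='] }]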

  /**
   * @param {Object} params
   * @param {ChatCompletionPayload} params.payload
   * @param {onTokenProgress} params.onProgress
   * @param {AbortController} params.abortController
   */
  async chatCompletion({ payload, onProgress, abortController = null }) {
    try {
      let intermediateReply = '';

      const parameters = ollamaPayloadSchema.parse(payload);
      const messages = OllamaClient.formatOpenAIMessages(payload.messages);

      if (parameters.stream) {
        const stream = await this.client.chat({
          messages,
          ...parameters,
        });

        for await (const chunk of stream) {
          const token = chunk.message.content;
          intermediateReply += token;
          onProgress(token);
          if (abortController.signal.aborted) {
            stream.controller.abort();
            break;
          }
        }
      }
      // TODO: regular completion
      else {
        // const generation = await this.client.generate(payload);
      }

      return intermediateReply;
    } catch (err) {
      logger.error('[OllamaClient.chatCompletion]', err);
      throw err;
    }
  }
}

module.exports = { OllamaClient, ollamaPayloadSchema };
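
A minimal usage sketch of the new client (assumes a local Ollama instance on the default port; the model name and base64 payload are placeholders):

const { OllamaClient } = require('~/app/clients/OllamaClient');

async function demo() {
  // List locally available models, e.g. ['llava:latest', 'mistral:latest'].
  const models = await OllamaClient.fetchModels('http://localhost:11434');
  console.log(models);

  // Stream a chat completion; the payload uses the OpenAI chat format, so
  // image content parts are converted to Ollama's `images` field internally.
  const client = new OllamaClient({ baseURL: 'http://localhost:11434' });
  const reply = await client.chatCompletion({
    payload: {
      model: 'llava',
      stream: true,
      messages: [
        {
          role: 'user',
          content: [
            { type: 'text', text: 'What is in this image?' },
            { type: 'image_url', image_url: { url: 'data:image/png;base64,iVBORw0KGgo=' } },
          ],
        },
      ],
    },
    onProgress: (token) => process.stdout.write(token),
    abortController: new AbortController(),
  });
  console.log(reply);
}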
63 changes: 53 additions & 10 deletions api/app/clients/OpenAIClient.js
@@ -1,4 +1,5 @@
 const OpenAI = require('openai');
+const { OllamaClient } = require('./OllamaClient');
 const { HttpsProxyAgent } = require('https-proxy-agent');
 const {
   Constants,
@@ -234,23 +235,52 @@ class OpenAIClient extends BaseClient {
    * @param {MongoFile[]} attachments
    */
   checkVisionRequest(attachments) {
+    if (!attachments) {
+      return;
+    }
+
     const availableModels = this.options.modelsConfig?.[this.options.endpoint];
-    this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels });
+    if (!availableModels) {
+      return;
+    }
 
-    const visionModelAvailable = availableModels?.includes(this.defaultVisionModel);
-    if (
-      attachments &&
-      attachments.some((file) => file?.type && file?.type?.includes('image')) &&
-      visionModelAvailable &&
-      !this.isVisionModel
-    ) {
-      this.modelOptions.model = this.defaultVisionModel;
-      this.isVisionModel = true;
+    let visionRequestDetected = false;
+    for (const file of attachments) {
+      if (file?.type?.includes('image')) {
+        visionRequestDetected = true;
+        break;
+      }
     }
+    if (!visionRequestDetected) {
+      return;
+    }
 
+    this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels });
     if (this.isVisionModel) {
       delete this.modelOptions.stop;
+      return;
     }
+
+    for (const model of availableModels) {
+      if (!validateVisionModel({ model, availableModels })) {
+        continue;
+      }
+      this.modelOptions.model = model;
+      this.isVisionModel = true;
+      delete this.modelOptions.stop;
+      return;
+    }
+
+    if (!availableModels.includes(this.defaultVisionModel)) {
+      return;
+    }
+    if (!validateVisionModel({ model: this.defaultVisionModel, availableModels })) {
+      return;
+    }
+
+    this.modelOptions.model = this.defaultVisionModel;
+    this.isVisionModel = true;
+    delete this.modelOptions.stop;
   }
 
   setupTokens() {
@@ -715,6 +745,10 @@ class OpenAIClient {
    * In case of failure, it will return the default title, "New Chat".
    */
   async titleConvo({ text, conversationId, responseText = '' }) {
+    if (this.options.attachments) {
+      delete this.options.attachments;
+    }
+
     let title = 'New Chat';
     const convo = `||>User:
 "${truncateText(text)}"
@@ -1124,6 +1158,15 @@ ${convo}
       });
     }
 
+    if (this.options.attachments && this.options.endpoint?.toLowerCase() === 'ollama') {
+      const ollamaClient = new OllamaClient({ baseURL });
+      return await ollamaClient.chatCompletion({
+        payload: modelOptions,
+        onProgress,
+        abortController,
+      });
+    }
+
     let UnexpectedRoleError = false;
     if (modelOptions.stream) {
       const stream = await openai.beta.chat.completions
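Taken together, the reworked checkVisionRequest resolves a vision model in order of preference: it keeps the current model if it already validates, otherwise switches to the first valid entry in availableModels, and only falls back to this.defaultVisionModel when that default is both listed and valid. A sketch of the expected behavior (model names illustrative, mirroring the spec below):

const client = new OpenAIClient('test-api-key', {
  endpoint: 'ollama',
  modelOptions: { model: 'mistral' },
  modelsConfig: { ollama: ['mistral', 'llava', 'bakllava'] },
});

client.checkVisionRequest([{ type: 'image/png' }]);
// 'mistral' is not a vision model, so the first valid entry wins:
console.log(client.modelOptions.model); // 'llava'
console.log(client.isVisionModel); // true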
42 changes: 40 additions & 2 deletions api/app/clients/specs/OpenAIClient.test.js
@@ -157,12 +157,19 @@ describe('OpenAIClient', () => {
     azureOpenAIApiVersion: '2020-07-01-preview',
   };
 
+  let originalWarn;
+
   beforeAll(() => {
-    jest.spyOn(console, 'warn').mockImplementation(() => {});
+    originalWarn = console.warn;
+    console.warn = jest.fn();
   });
 
   afterAll(() => {
-    console.warn.mockRestore();
+    console.warn = originalWarn;
   });
 
+  beforeEach(() => {
+    console.warn.mockClear();
+  });
+
   beforeEach(() => {
@@ -662,4 +669,35 @@ describe('OpenAIClient', () => {
       expect(constructorArgs.baseURL).toBe(expectedURL);
     });
   });
+
+  describe('checkVisionRequest functionality', () => {
+    let client;
+    const attachments = [{ type: 'image/png' }];
+
+    beforeEach(() => {
+      client = new OpenAIClient('test-api-key', {
+        endpoint: 'ollama',
+        modelOptions: {
+          model: 'initial-model',
+        },
+        modelsConfig: {
+          ollama: ['initial-model', 'llava', 'other-model'],
+        },
+      });
+
+      client.defaultVisionModel = 'non-valid-default-model';
+    });
+
+    afterEach(() => {
+      jest.restoreAllMocks();
+    });
+
+    it('should set "llava" as the model if it is the first valid model when default validation fails', () => {
+      client.checkVisionRequest(attachments);
+
+      expect(client.modelOptions.model).toBe('llava');
+      expect(client.isVisionModel).toBeTruthy();
+      expect(client.modelOptions.stop).toBeUndefined();
+    });
+  });
 });
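
A companion case one might add to the same describe block (hypothetical, not part of this commit) to pin down the early-return path: when the configured model already validates as a vision model, it is kept as-is.

    it('keeps the current model when it already validates as a vision model', () => {
      client.modelOptions.model = 'llava';
      client.checkVisionRequest(attachments);

      expect(client.modelOptions.model).toBe('llava');
      expect(client.isVisionModel).toBeTruthy();
      expect(client.modelOptions.stop).toBeUndefined();
    });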
1 change: 1 addition & 0 deletions api/package.json
@@ -75,6 +75,7 @@
     "multer": "^1.4.5-lts.1",
     "nodejs-gpt": "^1.37.4",
     "nodemailer": "^6.9.4",
+    "ollama": "^0.5.0",
     "openai": "4.36.0",
     "openai-chat-tokens": "^0.2.8",
     "openid-client": "^5.4.2",
54 changes: 2 additions & 52 deletions api/server/services/ModelService.js
@@ -2,60 +2,11 @@ const axios = require('axios');
 const { HttpsProxyAgent } = require('https-proxy-agent');
 const { EModelEndpoint, defaultModels, CacheKeys } = require('librechat-data-provider');
 const { extractBaseURL, inputSchema, processModelData, logAxiosError } = require('~/utils');
+const { OllamaClient } = require('~/app/clients/OllamaClient');
 const getLogStores = require('~/cache/getLogStores');
 const { logger } = require('~/config');
 
 const { openAIApiKey, userProvidedOpenAI } = require('./Config/EndpointService').config;
 
-/**
- * Extracts the base URL from the provided URL.
- * @param {string} fullURL - The full URL.
- * @returns {string} The base URL.
- */
-function deriveBaseURL(fullURL) {
-  try {
-    const parsedUrl = new URL(fullURL);
-    const protocol = parsedUrl.protocol;
-    const hostname = parsedUrl.hostname;
-    const port = parsedUrl.port;
-
-    // Check if the parsed URL components are meaningful
-    if (!protocol || !hostname) {
-      return fullURL;
-    }
-
-    // Reconstruct the base URL
-    return `${protocol}//${hostname}${port ? `:${port}` : ''}`;
-  } catch (error) {
-    logger.error('Failed to derive base URL', error);
-    return fullURL; // Return the original URL in case of any exception
-  }
-}
-
-/**
- * Fetches Ollama models from the specified base API path.
- * @param {string} baseURL
- * @returns {Promise<string[]>} The Ollama models.
- */
-const fetchOllamaModels = async (baseURL) => {
-  let models = [];
-  if (!baseURL) {
-    return models;
-  }
-  try {
-    const ollamaEndpoint = deriveBaseURL(baseURL);
-    /** @type {Promise<AxiosResponse<OllamaListResponse>>} */
-    const response = await axios.get(`${ollamaEndpoint}/api/tags`);
-    models = response.data.models.map((tag) => tag.name);
-    return models;
-  } catch (error) {
-    const logMessage =
-      'Failed to fetch models from Ollama API. If you are not using Ollama directly, and instead, through some aggregator or reverse proxy that handles fetching via OpenAI spec, ensure the name of the endpoint doesn\'t start with `ollama` (case-insensitive).';
-    logger.error(logMessage, error);
-    return [];
-  }
-};
-
 /**
  * Fetches OpenAI models from the specified base API path or Azure, based on the provided configuration.
  *
@@ -92,7 +43,7 @@ const fetchModels = async ({
   }
 
   if (name && name.toLowerCase().startsWith('ollama')) {
-    return await fetchOllamaModels(baseURL);
+    return await OllamaClient.fetchModels(baseURL);
   }
 
   try {
@@ -281,7 +232,6 @@ const getGoogleModels = () => {
 
 module.exports = {
   fetchModels,
-  deriveBaseURL,
   getOpenAIModels,
   getChatGPTBrowserModels,
   getAnthropicModels,
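With deriveBaseURL now exported from ~/utils (and shared by OllamaClient), its behavior per the removed implementation above is simply protocol + hostname + optional port. A sketch of expected outputs:

const { deriveBaseURL } = require('~/utils');

deriveBaseURL('http://localhost:11434/v1/chat/completions'); // 'http://localhost:11434'
deriveBaseURL('https://example.com:8443/api/tags'); // 'https://example.com:8443'
deriveBaseURL('not a url'); // 'not a url' (returned unchanged when parsing fails)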