feat: Accurate Token Usage Tracking & Optional Balance (danny-avila#1018)

* refactor(Chains/llms): allow passing callbacks

* refactor(BaseClient): accurately count completion tokens as generation only

* refactor(OpenAIClient): remove unused getTokenCountForResponse, pass streaming var and callbacks in initializeLLM

* wip: summary prompt tokens

* refactor(summarizeMessages): new cut-off strategy that generates a better summary by adding context from beginning, truncating the middle, and providing the end
wip: draft out relevant providers and variables for token tracing

* refactor(createLLM): make streaming prop false by default

* chore: remove use of getTokenCountForResponse

* refactor(agents): use BufferMemory, as ConversationSummaryBufferMemory token usage is not easy to trace

* chore: remove passing of streaming prop, also console log useful vars for tracing

* feat: formatFromLangChain helper function to count tokens for ChatModelStart

* refactor(initializeLLM): add role for LLM tracing

* chore(formatFromLangChain): update JSDoc

* feat(formatMessages): formats langChain messages into OpenAI payload format

* chore: install openai-chat-tokens

* refactor(formatMessage): optimize conditional langChain logic
fix(formatFromLangChain): fix destructuring

* feat: accurate prompt tokens for ChatModelStart before generation (see the token-counting sketch after this list)

* refactor(handleChatModelStart): move to callbacks dir, use factory function

* refactor(initializeLLM): rename 'role' to 'context'

* feat(Balance/Transaction): new schema/models for tracking token spend
refactor(Key): factor out model export to separate file

* refactor(initializeClient): add req,res objects to client options

* feat: add-balance script to add to an existing user's token balance
refactor(Transaction): use multiplier map/function, return balance update

* refactor(Tx): update enum for tokenType, return 1 for multiplier if no map match

* refactor(Tx): add a fair fallback value multiplier in case the config result is undefined

* refactor(Balance): rename 'tokens' to 'tokenCredits'

* feat: balance check, add tx.js for new tx-related methods and tests

* chore(summaryPrompts): update prompt token count

* refactor(callbacks): pass req, res
wip: check balance

* refactor(Tx): make convoId a String type, fix(calculateTokenValue)

* refactor(BaseClient): add conversationId as client prop when assigned

* feat(RunManager): track LLM runs with manager, track token spend from LLM,
refactor(OpenAIClient): use RunManager to create callbacks, pass user prop to langchain api calls

* feat(spendTokens): helper to spend prompt/completion tokens

* feat(checkBalance): add helper to check, log, deny request if balance doesn't have enough funds
refactor(Balance): static check method to return object instead of boolean now
wip(OpenAIClient): implement use of checkBalance

* refactor(initializeLLM): add token buffer to ensure a summary isn't generated when the subsequent payload is too large
refactor(OpenAIClient): add checkBalance
refactor(createStartHandler): add checkBalance

* chore: remove prompt and completion token logging from route handler

* chore(spendTokens): add JSDoc

* feat(logTokenCost): record transactions for basic api calls

* chore(ask/edit): invoke getResponseSender only once per API call

* refactor(ask/edit): pass promptTokens to getIds and include in abort data

* refactor(getIds -> getReqData): rename function

* refactor(Tx): increase value if incomplete message

* feat: record tokenUsage when message is aborted

* refactor: subtract tokens when payload includes function_call

* refactor: add namespace for token_balance

* fix(spendTokens): only execute if corresponding token type amounts are defined

* refactor(checkBalance): throws Error if not enough token credits

* refactor(runTitleChain): pass and use signal, spread object props in create helpers, and use 'call' instead of 'run'

* fix(abortMiddleware): circular dependency, and default to empty string for completionTokens

* fix: properly cancel title requests when there aren't enough tokens to generate

* feat(predictNewSummary): custom chain for summaries to allow signal passing
refactor(summaryBuffer): use new custom chain

* feat(RunManager): add getRunByConversationId method, refactor: remove run and throw llm error on handleLLMError

* refactor(createStartHandler): if summary, add error details to runs

* fix(OpenAIClient): support aborting from summarization & showing error to user
refactor(summarizeMessages): remove unnecessary operations counting summaryPromptTokens and note for alternative, pass signal to summaryBuffer

* refactor(logTokenCost -> recordTokenUsage): rename

* refactor(checkBalance): include promptTokens in errorMessage

* refactor(checkBalance/spendTokens): move to models dir

* fix(createLanguageChain): correctly pass config

* refactor(initializeLLM/title): add tokenBuffer of 150 for balance check

* refactor(openAPIPlugin): pass signal and memory, filter functions by the one being called

* refactor(createStartHandler): add error to run if context is plugins as well

* refactor(RunManager/handleLLMError): throw error immediately if plugins, don't remove run

* refactor(PluginsClient): pass memory and signal to tools, cleanup error handling logic

* chore: use strict equality for addTitle condition

* refactor(checkBalance): move checkBalance to execute after userMessage and tokenCounts are saved, also make conditional

* style: icon changes to match official

* fix(BaseClient): getTokenCountForResponse -> getTokenCount

* fix(formatLangChainMessages): add kwargs as fallback prop from lc_kwargs, update JSDoc

* refactor(Tx.create): does not update balance if CHECK_BALANCE is not enabled

* fix(e2e/cleanUp): cleanup new collections, import all model methods from index

* fix(config/add-balance): add uncaughtException listener

* fix: circular dependency

* refactor(initializeLLM/checkBalance): append new generations to errorMessage if cost exceeds balance

* fix(handleResponseMessage): only record token usage in this method if not error and completion is not skipped

* fix(createStartHandler): correct condition for generations

* chore: bump postcss due to moderate severity vulnerability

* chore: bump zod due to low severity vulnerability

* chore: bump openai & data-provider version

* feat(types): OpenAI Message types

* chore: update bun lockfile

* refactor(CodeBlock): add error block formatting

* refactor(utils/Plugin): factor out formatJSON and cn to separate files (json.ts and cn.ts), add extractJSON

* chore(logViolation): delete user_id after error is logged

* refactor(getMessageError -> Error): change to React.FC, add token_balance handling, use extractJSON to determine JSON instead of regex

* fix(DALL-E): use latest openai SDK

* chore: reorganize imports, fix type issue

* feat(server): add balance route

* fix(api/models): add auth

* feat(data-provider): /api/balance query

* feat: show balance if checking is enabled, refetch on final message or error

* chore: update docs, .env.example with token_usage info, add balance script command

* fix(Balance): fallback to empty obj for balance query

* style: slight adjustment of balance element

* docs(token_usage): add PR notes
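
Note on the prompt-token counting steps above: LangChain messages are first converted into the OpenAI payload shape, and the `openai-chat-tokens` package then estimates prompt tokens inside the `handleChatModelStart` callback, before any generation happens. The following is a minimal sketch of that idea, not the PR's exact code; the helper name `countPromptTokens` and the sample messages are illustrative only.

const { promptTokensEstimate } = require('openai-chat-tokens');

// Illustrative helper: `messages` must already be in OpenAI payload format,
// i.e. [{ role, content }], which is what the PR's formatLangChainMessages
// helper is described as producing.
function countPromptTokens(messages, functions) {
  return promptTokensEstimate({ messages, functions });
}

// Example: a start handler would report this count back to the request handler
// so the balance check and the final transaction use the same number.
const promptTokens = countPromptTokens([
  { role: 'system', content: 'You are a helpful assistant.' },
  { role: 'user', content: 'Summarize our conversation so far.' },
]);
console.log({ promptTokens });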
danny-avila authored Oct 5, 2023
1 parent be71a19 commit 365c39c
Showing 81 changed files with 1,607 additions and 294 deletions.
15 changes: 15 additions & 0 deletions .env.example
@@ -13,6 +13,21 @@ APP_TITLE=LibreChat
HOST=localhost
PORT=3080

# Note: the following enables user balances, which you can add manually
# or you will need to build out a balance accruing system for users.
# For more info, see https://docs.librechat.ai/features/token_usage.html

# To manually add balances, run the following command:
# `npm run add-balance`

# You can also specify the email and token credit amount to add, e.g.:
# `npm run add-balance [email protected] 1000`

# This works well to track your own usage for personal use; 1000 credits = $0.001 (1 mill USD)

# Set to true to enable token credit balances for the OpenAI/Plugins endpoints
CHECK_BALANCE=false

# Automated Moderation System
# The Automated Moderation System uses a scoring mechanism to track user violations. As users commit actions
# like excessive logins, registrations, or messaging, they accumulate violation scores. Upon reaching
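For reference, the rate documented above works out to 1,000,000 token credits per US dollar. A quick sketch of the conversion (helper names here are illustrative, not part of the PR):

// Illustrative only: 1000 credits = $0.001 USD, i.e. 1,000,000 credits per USD.
const CREDITS_PER_USD = 1000000;

const usdToCredits = (usd) => Math.round(usd * CREDITS_PER_USD);
const creditsToUsd = (credits) => credits / CREDITS_PER_USD;

console.log(usdToCredits(5)); // 5000000 credits for a $5 top-up
console.log(creditsToUsd(1000)); // 0.001, matching the example above
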
36 changes: 30 additions & 6 deletions api/app/clients/BaseClient.js
@@ -1,7 +1,8 @@
const crypto = require('crypto');
const TextStream = require('./TextStream');
const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('../../models');
const { addSpaceIfNeeded } = require('../../server/utils');
const { addSpaceIfNeeded, isEnabled } = require('../../server/utils');
const checkBalance = require('../../models/checkBalance');

class BaseClient {
constructor(apiKey, options = {}) {
@@ -39,6 +40,12 @@ class BaseClient {
throw new Error('Subclasses attempted to call summarizeMessages without implementing it');
}

async recordTokenUsage({ promptTokens, completionTokens }) {
if (this.options.debug) {
console.debug('`recordTokenUsage` not implemented.', { promptTokens, completionTokens });
}
}

getBuildMessagesOptions() {
throw new Error('Subclasses must implement getBuildMessagesOptions');
}
@@ -64,6 +71,7 @@ class BaseClient {
let responseMessageId = opts.responseMessageId ?? crypto.randomUUID();
let head = isEdited ? responseMessageId : parentMessageId;
this.currentMessages = (await this.loadHistory(conversationId, head)) ?? [];
this.conversationId = conversationId;

if (isEdited && !isContinued) {
responseMessageId = crypto.randomUUID();
@@ -114,8 +122,8 @@ class BaseClient {
text: message,
});

if (typeof opts?.getIds === 'function') {
opts.getIds({
if (typeof opts?.getReqData === 'function') {
opts.getReqData({
userMessage,
conversationId,
responseMessageId,
@@ -420,6 +428,21 @@ class BaseClient {
await this.saveMessageToDatabase(userMessage, saveOptions, user);
}

if (isEnabled(process.env.CHECK_BALANCE)) {
await checkBalance({
req: this.options.req,
res: this.options.res,
txData: {
user: this.user,
tokenType: 'prompt',
amount: promptTokens,
debug: this.options.debug,
model: this.modelOptions.model,
},
});
}

const completion = await this.sendCompletion(payload, opts);
const responseMessage = {
messageId: responseMessageId,
conversationId,
@@ -428,14 +451,15 @@ class BaseClient {
isEdited,
model: this.modelOptions.model,
sender: this.sender,
text: addSpaceIfNeeded(generation) + (await this.sendCompletion(payload, opts)),
text: addSpaceIfNeeded(generation) + completion,
promptTokens,
};

if (tokenCountMap && this.getTokenCountForResponse) {
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
if (tokenCountMap && this.getTokenCount) {
responseMessage.tokenCount = this.getTokenCount(completion);
responseMessage.completionTokens = responseMessage.tokenCount;
}
await this.recordTokenUsage(responseMessage);
await this.saveMessageToDatabase(responseMessage, saveOptions, user);
delete responseMessage.tokenCount;
return responseMessage;
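The `checkBalance` helper required above is not part of this excerpt. Going by the commit notes (Balance exposes a static `check` method that returns an object rather than a boolean, and `checkBalance` logs a `token_balance` violation and throws when credits are insufficient), a simplified sketch could look like the following; the returned field names and require paths are assumptions, not the PR's exact code:

// Simplified sketch of the balance check described in the commit notes.
const Balance = require('./Balance');
const { logViolation } = require('../cache');

const checkBalance = async ({ req, res, txData }) => {
  // Assumed shape: Balance.check returns an object such as
  // { canSpend, balance, tokenCost } instead of a plain boolean.
  const { canSpend, balance, tokenCost } = await Balance.check(txData);

  if (canSpend) {
    return true;
  }

  // Deny the request: log a 'token_balance' violation, then throw so the
  // caller can surface the error (promptTokens included, per the notes).
  const type = 'token_balance';
  const errorMessage = { type, balance, tokenCost, promptTokens: txData.amount };
  await logViolation(req, res, type, errorMessage, 0);
  throw new Error(JSON.stringify(errorMessage));
};

module.exports = checkBalance;
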
123 changes: 92 additions & 31 deletions api/app/clients/OpenAIClient.js
@@ -1,12 +1,13 @@
const BaseClient = require('./BaseClient');
const ChatGPTClient = require('./ChatGPTClient');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const ChatGPTClient = require('./ChatGPTClient');
const BaseClient = require('./BaseClient');
const { getModelMaxTokens, genAzureChatCompletion } = require('../../utils');
const { truncateText, formatMessage, CUT_OFF_PROMPT } = require('./prompts');
const spendTokens = require('../../models/spendTokens');
const { createLLM, RunManager } = require('./llm');
const { summaryBuffer } = require('./memory');
const { runTitleChain } = require('./chains');
const { tokenSplit } = require('./document');
const { createLLM } = require('./llm');

// Cache to store Tiktoken instances
const tokenizersCache = {};
@@ -335,6 +336,10 @@ class OpenAIClient extends BaseClient {
result.tokenCountMap = tokenCountMap;
}

if (promptTokens >= 0 && typeof this.options.getReqData === 'function') {
this.options.getReqData({ promptTokens });
}

return result;
}

@@ -409,26 +414,24 @@ class OpenAIClient extends BaseClient {
return reply.trim();
}

getTokenCountForResponse(response) {
return this.getTokenCountForMessage({
role: 'assistant',
content: response.text,
});
}

initializeLLM({
model = 'gpt-3.5-turbo',
modelName,
temperature = 0.2,
presence_penalty = 0,
frequency_penalty = 0,
max_tokens,
streaming,
context,
tokenBuffer,
initialMessageCount,
}) {
const modelOptions = {
modelName: modelName ?? model,
temperature,
presence_penalty,
frequency_penalty,
user: this.user,
};

if (max_tokens) {
@@ -451,11 +454,22 @@ class OpenAIClient extends BaseClient {
};
}

const { req, res, debug } = this.options;
const runManager = new RunManager({ req, res, debug, abortController: this.abortController });
this.runManager = runManager;

const llm = createLLM({
modelOptions,
configOptions,
openAIApiKey: this.apiKey,
azure: this.azure,
streaming,
callbacks: runManager.createCallbacks({
context,
tokenBuffer,
conversationId: this.conversationId,
initialMessageCount,
}),
});

return llm;
@@ -471,19 +485,24 @@ class OpenAIClient extends BaseClient {
const { OPENAI_TITLE_MODEL } = process.env ?? {};

const modelOptions = {
model: OPENAI_TITLE_MODEL ?? 'gpt-3.5-turbo-0613',
model: OPENAI_TITLE_MODEL ?? 'gpt-3.5-turbo',
temperature: 0.2,
presence_penalty: 0,
frequency_penalty: 0,
max_tokens: 16,
};

try {
const llm = this.initializeLLM(modelOptions);
title = await runTitleChain({ llm, text, convo });
this.abortController = new AbortController();
const llm = this.initializeLLM({ ...modelOptions, context: 'title', tokenBuffer: 150 });
title = await runTitleChain({ llm, text, convo, signal: this.abortController.signal });
} catch (e) {
if (e?.message?.toLowerCase()?.includes('abort')) {
this.options.debug && console.debug('Aborted title generation');
return;
}
console.log('There was an issue generating title with LangChain, trying the old method...');
console.error(e.message, e);
this.options.debug && console.error(e.message, e);
modelOptions.model = OPENAI_TITLE_MODEL ?? 'gpt-3.5-turbo';
const instructionsPayload = [
{
@@ -514,11 +533,19 @@ ${convo}
let context = messagesToRefine;
let prompt;

const { OPENAI_SUMMARY_MODEL } = process.env ?? {};
const { OPENAI_SUMMARY_MODEL = 'gpt-3.5-turbo' } = process.env ?? {};
const maxContextTokens = getModelMaxTokens(OPENAI_SUMMARY_MODEL) ?? 4095;

// Token count of messagesToSummarize: start with 3 tokens for the assistant label
const excessTokenCount = context.reduce((acc, message) => acc + message.tokenCount, 3);
// 3 tokens for the assistant label, and 98 for the summarizer prompt (101)
let promptBuffer = 101;

/*
* Note: token counting here is to block summarization if it exceeds the spend; complete
* accuracy is not important. Actual spend will happen after successful summarization.
*/
const excessTokenCount = context.reduce(
(acc, message) => acc + message.tokenCount,
promptBuffer,
);

if (excessTokenCount > maxContextTokens) {
({ context } = await this.getMessagesWithinTokenLimit(context, maxContextTokens));
@@ -528,30 +555,38 @@ ${convo}
this.options.debug &&
console.debug('Summary context is empty, using latest message within token limit');

promptBuffer = 32;
const { text, ...latestMessage } = messagesToRefine[messagesToRefine.length - 1];
const splitText = await tokenSplit({
text,
chunkSize: maxContextTokens - 40,
returnSize: 1,
chunkSize: Math.floor((maxContextTokens - promptBuffer) / 3),
});

const newText = splitText[0];

if (newText.length < text.length) {
prompt = CUT_OFF_PROMPT;
}
const newText = `${splitText[0]}\n...[truncated]...\n${splitText[splitText.length - 1]}`;
prompt = CUT_OFF_PROMPT;

context = [
{
...latestMessage,
text: newText,
},
formatMessage({
message: {
...latestMessage,
text: newText,
},
userName: this.options?.name,
assistantName: this.options?.chatGptLabel,
}),
];
}
// TODO: We can accurately count the tokens here before handleChatModelStart
// by recreating the summary prompt (single message) to avoid LangChain handling

const initialPromptTokens = this.maxContextTokens - remainingContextTokens;
this.options.debug && console.debug(`initialPromptTokens: ${initialPromptTokens}`);

const llm = this.initializeLLM({
model: OPENAI_SUMMARY_MODEL,
temperature: 0.2,
context: 'summary',
tokenBuffer: initialPromptTokens,
});

try {
@@ -565,6 +600,7 @@ ${convo}
assistantName: this.options?.chatGptLabel ?? this.options?.modelLabel,
},
previous_summary: this.previous_summary?.summary,
signal: this.abortController.signal,
});

const summaryTokenCount = this.getTokenCountForMessage(summaryMessage);
@@ -580,11 +616,36 @@ ${convo}

return { summaryMessage, summaryTokenCount };
} catch (e) {
console.error('Error refining messages');
console.error(e);
if (e?.message?.toLowerCase()?.includes('abort')) {
this.options.debug && console.debug('Aborted summarization');
const { run, runId } = this.runManager.getRunByConversationId(this.conversationId);
if (run && run.error) {
const { error } = run;
this.runManager.removeRun(runId);
throw new Error(error);
}
}
console.error('Error summarizing messages');
this.options.debug && console.error(e);
return {};
}
}

async recordTokenUsage({ promptTokens, completionTokens }) {
if (this.options.debug) {
console.debug('promptTokens', promptTokens);
console.debug('completionTokens', completionTokens);
}
await spendTokens(
{
user: this.user,
model: this.modelOptions.model,
context: 'message',
conversationId: this.conversationId,
},
{ promptTokens, completionTokens },
);
}
}

module.exports = OpenAIClient;
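
For completeness, `spendTokens` (called by `recordTokenUsage` above) is not shown in this excerpt either. Per the commit notes, it records one Transaction per token type, skips a type whose amount is undefined, and the Transaction model applies a per-model multiplier that falls back to 1 when the model has no entry in the map. A rough sketch under those assumptions; `rawAmount` and the require path are guesses at the real names:

// Rough sketch of the spendTokens helper described in the commit notes;
// Transaction.create is assumed to apply the token-rate multiplier and
// return the resulting balance update when CHECK_BALANCE is enabled.
const { Transaction } = require('./Transaction');

const spendTokens = async (txData, { promptTokens, completionTokens }) => {
  let prompt;
  let completion;

  // Only record a transaction when the corresponding amount is defined.
  if (promptTokens !== undefined) {
    prompt = await Transaction.create({
      ...txData,
      tokenType: 'prompt',
      rawAmount: -promptTokens,
    });
  }

  if (completionTokens !== undefined) {
    completion = await Transaction.create({
      ...txData,
      tokenType: 'completion',
      rawAmount: -completionTokens,
    });
  }

  return { prompt, completion };
};

module.exports = spendTokens;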