fix(client): optimize token counting algorithm
waylaidwanderer committed Mar 5, 2023
1 parent 011b0c4 commit 590b24b
Showing 1 changed file with 35 additions and 54 deletions.
89 changes: 35 additions & 54 deletions src/ChatGPTClient.js
@@ -6,7 +6,6 @@ import { fetchEventSource } from '@waylaidwanderer/fetch-event-source';
import { Agent } from 'undici';

const CHATGPT_MODEL = 'text-chat-davinci-002-sh-alpha-aoruigiofdj83';
const CHATGPT_TOKENIZER = get_encoding('cl100k_base');

export default class ChatGPTClient {
constructor(
@@ -51,9 +50,10 @@ export default class ChatGPTClient {
if (isChatGptModel) {
// Use these faux tokens to help the AI understand the context since we are building the chat log ourselves.
// Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason,
// without tripping the stop sequences, so I'm using "##im_start##" instead.
// without tripping the stop sequences, so I'm using "||>" instead.
this.startToken = '||>';
this.endToken = '';
this.gptEncoder = get_encoding('cl100k_base');
} else if (isUnofficialChatGptModel) {
this.startToken = '<|im_start|>';
this.endToken = '<|im_end|>';
@@ -62,9 +62,16 @@
'<|im_end|>': 100265,
});
} else {
this.startToken = '<|endoftext|>';
this.endToken = this.startToken;
this.gptEncoder = encoding_for_model('text-davinci-003');
// Previously I was trying to use "<|endoftext|>" but there seems to be some bug with OpenAI's token counting
// system that causes only the first "<|endoftext|>" to be counted as 1 token, and the rest are not treated
// as a single token. So we're using this instead.
this.startToken = '||>';
this.endToken = '';
try {
this.gptEncoder = encoding_for_model(this.modelOptions.model);
} catch {
this.gptEncoder = encoding_for_model('text-davinci-003');
}
}

if (!this.modelOptions.stop) {
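
The rewritten else branch above no longer hard-codes the text-davinci-003 encoder: it asks tiktoken for a model-specific encoder first and falls back only when the model name is unknown. A minimal sketch of that pattern, assuming the @dqbd/tiktoken bindings (the package name is an assumption; the diff only shows the encoding_for_model and get_encoding calls):

    import { encoding_for_model } from '@dqbd/tiktoken'; // assumed package

    // Try the model-specific encoder first; unknown model names throw,
    // so fall back to the text-davinci-003 encoding as the diff does.
    function getEncoderForModel(modelName) {
        try {
            return encoding_for_model(modelName);
        } catch {
            return encoding_for_model('text-davinci-003');
        }
    }

    const encoder = getEncoderForModel('some-future-model');
    console.log(encoder.encode('Hello world').length); // token count for the string
    encoder.free(); // the WASM-backed encoder should be freed when done
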
@@ -342,10 +349,7 @@

let currentTokenCount;
if (isChatGptModel) {
currentTokenCount = this.constructor.getTokenCountForMessages([
instructionsPayload,
messagePayload,
]);
currentTokenCount = this.getTokenCountForMessage(instructionsPayload) + this.getTokenCountForMessage(messagePayload);
} else {
currentTokenCount = this.getTokenCount(`${promptPrefix}${promptSuffix}`);
}
@@ -370,21 +374,8 @@
newPromptBody = `${promptPrefix}${messageString}${promptBody}`;
}

// The reason I don't simply get the token count of the messageString and add it to currentTokenCount is because
// joined words may combine into a single token. Actually, that isn't really applicable here, but I can't
// resist doing it the "proper" way.
let newTokenCount;
if (isChatGptModel) {
newTokenCount = this.constructor.getTokenCountForMessages([
instructionsPayload,
{
...messagePayload,
content: newPromptBody,
},
]);
} else {
newTokenCount = this.getTokenCount(`${newPromptBody}${promptSuffix}`);
}
const tokenCountForMessage = this.getTokenCount(messageString);
const newTokenCount = currentTokenCount + tokenCountForMessage;
if (newTokenCount > maxTokenCount) {
if (promptBody) {
// This message would put us over the token limit, so don't add it.
@@ -395,6 +386,7 @@
}
promptBody = newPromptBody;
currentTokenCount = newTokenCount;
// wait for next tick to avoid blocking the event loop
await new Promise((resolve) => setTimeout(resolve, 0));
return buildPromptBody();
}
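
The two hunks above are the heart of this commit: instead of re-encoding the full instructions-plus-prompt payload on every buildPromptBody iteration (roughly quadratic work as the prompt grows), the new code encodes each candidate message once and keeps a running sum. The deleted comment already conceded the trade-off: tokens can merge across boundaries when strings are joined, so the per-message sum is a close approximation rather than an exact count. A standalone sketch of the two strategies, using a crude whitespace tokenizer as a stand-in for the real encoder:

    // Crude stand-in for gptEncoder.encode; only the counting pattern matters.
    const encode = (text) => text.split(/\s+/).filter(Boolean);

    const maxTokenCount = 8;
    const messages = ['third message here', 'second message', 'first message'];

    // Before: re-encode the whole growing prompt on every iteration -> O(n^2).
    let prompt = '';
    for (const message of messages) {
        const candidate = `${message}\n${prompt}`;
        if (encode(candidate).length > maxTokenCount) break;
        prompt = candidate;
    }

    // After: encode each message once and accumulate a running total -> O(n).
    let promptBody = '';
    let currentTokenCount = 0;
    for (const message of messages) {
        const messageTokens = encode(message).length;
        if (currentTokenCount + messageTokens > maxTokenCount) break;
        promptBody = `${message}\n${promptBody}`;
        currentTokenCount += messageTokens;
    }
    console.log(promptBody, currentTokenCount);
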
@@ -404,19 +396,14 @@
await buildPromptBody();

const prompt = `${promptBody}${promptSuffix}`;

let numTokens;
if (isChatGptModel) {
messagePayload.content = prompt;
numTokens = this.constructor.getTokenCountForMessages([
instructionsPayload,
messagePayload,
]);
} else {
numTokens = this.getTokenCount(prompt);
// Add 2 tokens for metadata after all messages have been counted.
currentTokenCount += 2;
}

// Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
this.modelOptions.max_tokens = Math.min(this.maxContextTokens - numTokens, this.maxResponseTokens);
this.modelOptions.max_tokens = Math.min(this.maxContextTokens - currentTokenCount, this.maxResponseTokens);

if (isChatGptModel) {
return [
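
With a single running count, the response budget in the hunk above becomes one subtraction. A worked example with hypothetical numbers:

    const maxContextTokens = 4097;   // model's total context window
    const maxResponseTokens = 1024;  // configured ceiling for the completion
    const currentTokenCount = 3200;  // prompt tokens, including the +2 metadata

    // Leave whatever room the prompt did not use, but never request
    // more than the configured response budget.
    const maxTokens = Math.min(maxContextTokens - currentTokenCount, maxResponseTokens);
    console.log(maxTokens); // min(4097 - 3200, 1024) = 897
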
@@ -434,30 +421,24 @@
/**
* Algorithm adapted from "6. Counting tokens for chat API calls" of
* https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
* @param {*[]} messages
*
* An additional 2 tokens need to be added for metadata after all messages have been counted.
*
* @param {*} message
*/
static getTokenCountForMessages(messages) {
// Get the encoding tokenizer
const tokenizer = CHATGPT_TOKENIZER;

// Map each message to the number of tokens it contains
const messageTokenCounts = messages.map((message) => {
// Map each property of the message to the number of tokens it contains
const propertyTokenCounts = Object.entries(message).map(([key, value]) => {
// Count the number of tokens in the property value
const numTokens = tokenizer.encode(value).length;

// Subtract 1 token if the property key is 'name'
const adjustment = (key === 'name') ? 1 : 0;
return numTokens - adjustment;
});

// Sum the number of tokens in all properties and add 4 for metadata
return propertyTokenCounts.reduce((a, b) => a + b, 4);
getTokenCountForMessage(message) {
// Map each property of the message to the number of tokens it contains
const propertyTokenCounts = Object.entries(message).map(([key, value]) => {
// Count the number of tokens in the property value
const numTokens = this.getTokenCount(value);

// Subtract 1 token if the property key is 'name'
const adjustment = (key === 'name') ? 1 : 0;
return numTokens - adjustment;
});

// Sum the number of tokens in all messages and add 2 for metadata
return messageTokenCounts.reduce((a, b) => a + b, 2);
// Sum the number of tokens in all properties and add 4 for metadata
return propertyTokenCounts.reduce((a, b) => a + b, 4);
}

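The replacement method follows the cookbook heuristic cited in its docblock: encode every property value, charge 4 metadata tokens per message, subtract 1 when the property key is 'name', and let the caller add 2 once for reply priming (the currentTokenCount += 2 seen earlier). A standalone sketch, again with a stand-in tokenizer in place of this.getTokenCount:

    // Stand-in tokenizer; the instance method uses tiktoken via this.getTokenCount.
    const encode = (text) => text.split(/\s+/).filter(Boolean);

    function getTokenCountForMessage(message) {
        const propertyTokenCounts = Object.entries(message).map(([key, value]) => {
            const numTokens = encode(value).length;
            const adjustment = (key === 'name') ? 1 : 0; // subtract 1 for 'name'
            return numTokens - adjustment;
        });
        // Every message is wrapped with roughly 4 tokens of metadata.
        return propertyTokenCounts.reduce((a, b) => a + b, 4);
    }

    const payload = [
        { role: 'system', content: 'You are a helpful assistant.' },
        { role: 'user', content: 'Hello!' },
    ];
    // The +2 mirrors currentTokenCount += 2: the reply is primed with metadata.
    const total = payload.reduce((sum, m) => sum + getTokenCountForMessage(m), 2);
    console.log(total); // 18 with this crude tokenizer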
