Skip to content

Commit

Permalink
chat: faster KV shift, continue generating, fix stop sequences (nomic…
Browse files Browse the repository at this point in the history
…-ai#2781)

* Don't stop generating at end of context
* Use llama_kv_cache ops to shift context
* Fix and improve reverse prompt detection
* Replace prompt recalc callback with a flag to disallow context shift
  • Loading branch information
cebtenzzre authored Aug 7, 2024
1 parent 90de2d3 commit be66ec8
Show file tree
Hide file tree
Showing 16 changed files with 280 additions and 225 deletions.
2 changes: 1 addition & 1 deletion gpt4all-backend/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ set(LLMODEL_VERSION_PATCH 0)
set(LLMODEL_VERSION "${LLMODEL_VERSION_MAJOR}.${LLMODEL_VERSION_MINOR}.${LLMODEL_VERSION_PATCH}")
project(llmodel VERSION ${LLMODEL_VERSION} LANGUAGES CXX C)

set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD 23)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
set(BUILD_SHARED_LIBS ON)
Expand Down
35 changes: 31 additions & 4 deletions gpt4all-backend/llamamodel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,10 +531,7 @@ size_t LLamaModel::restoreState(const uint8_t *src)
std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special)
{
bool atStart = m_tokenize_last_token == -1;
bool insertSpace = atStart || (
llama_token_get_attr(d_ptr->model, m_tokenize_last_token)
& (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)
);
bool insertSpace = atStart || isSpecialToken(m_tokenize_last_token);
std::vector<LLModel::Token> fres(str.length() + 4);
int32_t fres_len = llama_tokenize_gpt4all(
d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart,
Expand All @@ -546,6 +543,12 @@ std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::
return fres;
}

bool LLamaModel::isSpecialToken(Token id) const
{
return llama_token_get_attr(d_ptr->model, id)
& (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN);
}

std::string LLamaModel::tokenToString(Token id) const
{
std::vector<char> result(8, 0);
Expand Down Expand Up @@ -595,6 +598,30 @@ bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &toke
return res == 0;
}

void LLamaModel::shiftContext(PromptContext &promptCtx)
{
// infinite text generation via context shifting

// erase up to n_ctx*contextErase tokens
int n_keep = shouldAddBOS();
int n_past = promptCtx.n_past;
int n_discard = std::min(n_past - n_keep, int(promptCtx.n_ctx * promptCtx.contextErase));

assert(n_discard > 0);
if (n_discard <= 0)
return;

std::cerr << "Llama: context full, swapping: n_past = " << n_past << ", n_keep = " << n_keep
<< ", n_discard = " << n_discard << "\n";

// erase the first n_discard tokens from the context
llama_kv_cache_seq_rm (d_ptr->ctx, 0, n_keep, n_keep + n_discard);
llama_kv_cache_seq_add(d_ptr->ctx, 0, n_keep + n_discard, n_past, -n_discard);

promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard);
promptCtx.n_past = promptCtx.tokens.size();
}

int32_t LLamaModel::contextLength() const
{
return llama_n_ctx(d_ptr->ctx);
Expand Down
3 changes: 2 additions & 1 deletion gpt4all-backend/llamamodel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "llmodel.h"

#include <functional>
#include <memory>
#include <string>
#include <vector>
Expand Down Expand Up @@ -54,9 +53,11 @@ class LLamaModel : public LLModel {

protected:
std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override;
bool isSpecialToken(Token id) const override;
std::string tokenToString(Token id) const override;
Token sampleToken(PromptContext &ctx) const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
void shiftContext(PromptContext &promptCtx) override;
int32_t contextLength() const override;
const std::vector<Token> &endTokens() const override;
bool shouldAddBOS() const override;
Expand Down
14 changes: 6 additions & 8 deletions gpt4all-backend/llmodel.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class LLModel {
int32_t n_batch = 9;
float repeat_penalty = 1.10f;
int32_t repeat_last_n = 64; // last n tokens to penalize
float contextErase = 0.75f; // percent of context to erase if we exceed the context window
float contextErase = 0.5f; // percent of context to erase if we exceed the context window
};

using ProgressCallback = std::function<bool(float progress)>;
Expand All @@ -159,7 +159,7 @@ class LLModel {
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
bool allowContextShift,
PromptContext &ctx,
bool special = false,
std::string *fakeReply = nullptr);
Expand Down Expand Up @@ -213,9 +213,11 @@ class LLModel {
// These are pure virtual because subclasses need to implement as the default implementation of
// 'prompt' above calls these functions
virtual std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special = false) = 0;
virtual bool isSpecialToken(Token id) const = 0;
virtual std::string tokenToString(Token id) const = 0;
virtual Token sampleToken(PromptContext &ctx) const = 0;
virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const = 0;
virtual void shiftContext(PromptContext &promptCtx) = 0;
virtual int32_t contextLength() const = 0;
virtual const std::vector<Token> &endTokens() const = 0;
virtual bool shouldAddBOS() const = 0;
Expand All @@ -232,10 +234,6 @@ class LLModel {
return -1;
}

// This is a helper function called from the default implementation of 'prompt' but it can be
// shared by all base classes so it isn't virtual
void recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate);

const Implementation *m_implementation = nullptr;

ProgressCallback m_progressCallback;
Expand All @@ -249,11 +247,11 @@ class LLModel {

bool decodePrompt(std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
bool allowContextShift,
PromptContext &promptCtx,
std::vector<Token> embd_inp);
void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
bool allowContextShift,
PromptContext &promptCtx);

Token m_tokenize_last_token = -1; // not serialized
Expand Down
4 changes: 2 additions & 2 deletions gpt4all-backend/llmodel_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
const char *prompt_template,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
llmodel_recalculate_callback recalculate_callback,
bool allow_context_shift,
llmodel_prompt_context *ctx,
bool special,
const char *fake_reply)
Expand Down Expand Up @@ -135,7 +135,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
auto *fake_reply_p = fake_reply ? &fake_reply_str : nullptr;

// Call the C++ prompt method
wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, recalculate_callback,
wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, allow_context_shift,
wrapper->promptContext, special, fake_reply_p);

// Update the C context by giving access to the wrappers raw pointers to std::vector data
Expand Down
11 changes: 2 additions & 9 deletions gpt4all-backend/llmodel_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,6 @@ typedef bool (*llmodel_prompt_callback)(int32_t token_id);
*/
typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response);

/**
* Callback type for recalculation of context.
* @param whether the model is recalculating the context.
* @return a bool indicating whether the model should keep generating.
*/
typedef bool (*llmodel_recalculate_callback)(bool is_recalculating);

/**
* Embedding cancellation callback for use with llmodel_embed.
* @param batch_sizes The number of tokens in each batch that will be embedded.
Expand Down Expand Up @@ -175,7 +168,7 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);
* @param prompt_template A string representing the input prompt template.
* @param prompt_callback A callback function for handling the processing of prompt.
* @param response_callback A callback function for handling the generated response.
* @param recalculate_callback A callback function for handling recalculation requests.
* @param allow_context_shift Whether to allow shifting of context to make room for more input.
* @param special True if special tokens in the prompt should be processed, false otherwise.
* @param fake_reply A string to insert into context as the model's reply, or NULL to generate one.
* @param ctx A pointer to the llmodel_prompt_context structure.
Expand All @@ -184,7 +177,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
const char *prompt_template,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
llmodel_recalculate_callback recalculate_callback,
bool allow_context_shift,
llmodel_prompt_context *ctx,
bool special,
const char *fake_reply);
Expand Down
Loading

0 comments on commit be66ec8

Please sign in to comment.