
Commit

Bert
manyoso committed Jul 13, 2023
1 parent 315a1f2 commit 0efdbfc
Showing 20 changed files with 673 additions and 389 deletions.
714 changes: 401 additions & 313 deletions gpt4all-backend/bert.cpp

Large diffs are not rendered by default.

71 changes: 0 additions & 71 deletions gpt4all-backend/bert.h

This file was deleted.

44 changes: 44 additions & 0 deletions gpt4all-backend/bert_impl.h
@@ -0,0 +1,44 @@
#ifndef BERT_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#error This file is NOT meant to be included outside of bert.cpp. Doing so is DANGEROUS. Be sure to know what you are doing before proceeding to #define BERT_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#endif
#ifndef BERT_H
#define BERT_H

#include <string>
#include <functional>
#include <vector>
#include <memory>
#include "llmodel.h"

struct BertPrivate;
class Bert : public LLModel {
public:
    Bert();
    ~Bert();

    bool supportsEmbedding() const override { return true; }
    bool supportsCompletion() const override { return true; }
    bool loadModel(const std::string &modelPath) override;
    bool isModelLoaded() const override;
    size_t requiredMem(const std::string &modelPath) override;
    size_t stateSize() const override;
    size_t saveState(uint8_t *dest) const override;
    size_t restoreState(const uint8_t *src) override;
    void setThreadCount(int32_t n_threads) override;
    int32_t threadCount() const override;

    std::vector<float> embedding(const std::string &text) override;

private:
    std::unique_ptr<BertPrivate> d_ptr;

protected:
    std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
    Token sampleToken(PromptContext &ctx) const override;
    std::string tokenToString(Token) const override;
    bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
    int32_t contextLength() const override;
    const std::vector<Token>& endTokens() const override;
};

#endif // BERT_H
2 changes: 2 additions & 0 deletions gpt4all-backend/falcon_impl.h
@@ -16,6 +16,8 @@ class Falcon : public LLModel {
    Falcon();
    ~Falcon();

    bool supportsEmbedding() const override { return false; }
    bool supportsCompletion() const override { return true; }
    bool loadModel(const std::string &modelPath) override;
    bool isModelLoaded() const override;
    size_t requiredMem(const std::string &modelPath) override;
2 changes: 2 additions & 0 deletions gpt4all-backend/gptj_impl.h
@@ -15,6 +15,8 @@ class GPTJ : public LLModel {
    GPTJ();
    ~GPTJ();

    bool supportsEmbedding() const override { return false; }
    bool supportsCompletion() const override { return true; }
    bool loadModel(const std::string &modelPath) override;
    bool isModelLoaded() const override;
    size_t requiredMem(const std::string &modelPath) override;
2 changes: 2 additions & 0 deletions gpt4all-backend/llamamodel_impl.h
@@ -15,6 +15,8 @@ class LLamaModel : public LLModel {
    LLamaModel();
    ~LLamaModel();

    bool supportsEmbedding() const override { return false; }
    bool supportsCompletion() const override { return true; }
    bool loadModel(const std::string &modelPath) override;
    bool isModelLoaded() const override;
    size_t requiredMem(const std::string &modelPath) override;
7 changes: 7 additions & 0 deletions gpt4all-backend/llmodel.h
@@ -61,18 +61,25 @@ class LLModel {
    explicit LLModel() {}
    virtual ~LLModel() {}

    virtual bool supportsEmbedding() const = 0;
    virtual bool supportsCompletion() const = 0;
    virtual bool loadModel(const std::string &modelPath) = 0;
    virtual bool isModelLoaded() const = 0;
    virtual size_t requiredMem(const std::string &modelPath) = 0;
    virtual size_t stateSize() const { return 0; }
    virtual size_t saveState(uint8_t */*dest*/) const { return 0; }
    virtual size_t restoreState(const uint8_t */*src*/) { return 0; }

    // This method requires the model to return true from supportsCompletion; otherwise it
    // reports an error through the response callback and returns without generating anything
    virtual void prompt(const std::string &prompt,
                        std::function<bool(int32_t)> promptCallback,
                        std::function<bool(int32_t, const std::string&)> responseCallback,
                        std::function<bool(bool)> recalculateCallback,
                        PromptContext &ctx);

    virtual std::vector<float> embedding(const std::string &text);

    virtual void setThreadCount(int32_t /*n_threads*/) {}
    virtual int32_t threadCount() const { return 1; }

19 changes: 19 additions & 0 deletions gpt4all-backend/llmodel_c.cpp
@@ -166,6 +166,25 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
    ctx->context_erase = wrapper->promptContext.contextErase;
}

float *llmodel_embedding(llmodel_model model, const char *text, size_t *embedding_size)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    std::vector<float> embeddingVector = wrapper->llModel->embedding(text);
    float *embedding = (float *)malloc(embeddingVector.size() * sizeof(float));
    if (embedding == nullptr) {
        *embedding_size = 0;
        return nullptr;
    }
    std::copy(embeddingVector.begin(), embeddingVector.end(), embedding);
    *embedding_size = embeddingVector.size();
    return embedding;
}

void llmodel_free_embedding(float *ptr)
{
    free(ptr);
}

void llmodel_setThreadCount(llmodel_model model, int32_t n_threads)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
17 changes: 17 additions & 0 deletions gpt4all-backend/llmodel_c.h
@@ -171,6 +171,23 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
                    llmodel_recalculate_callback recalculate_callback,
                    llmodel_prompt_context *ctx);

/**
* Generate an embedding using the model.
* @param model A pointer to the llmodel_model instance.
* @param text A string representing the text to generate an embedding for.
* @param embedding_size A pointer to a size_t type that will be set by the call indicating the length
* of the returned floating point array.
* @return A pointer to an array of floating point values; ownership passes to the caller, which
* must release the memory with llmodel_free_embedding().
*/
float *llmodel_embedding(llmodel_model model, const char *text, size_t *embedding_size);

/**
* Frees the memory allocated by the llmodel_embedding function.
* @param ptr A pointer to the embedding as returned from llmodel_embedding.
*/
void llmodel_free_embedding(float *ptr);

/**
* Set the number of threads to be used by the model.
* @param model A pointer to the llmodel_model instance.
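A minimal ctypes sketch of how a binding might consume the two new entry points. Only the llmodel_embedding and llmodel_free_embedding signatures are taken from the header above; the shared-library name and the way the llmodel_model handle is obtained are illustrative assumptions, not part of this commit.

import ctypes
from ctypes import POINTER, c_char_p, c_float, c_size_t, c_void_p

# Assumption: the backend has been built as a shared library named libllmodel.so.
lib = ctypes.CDLL("libllmodel.so")

lib.llmodel_embedding.argtypes = [c_void_p, c_char_p, POINTER(c_size_t)]
lib.llmodel_embedding.restype = POINTER(c_float)
lib.llmodel_free_embedding.argtypes = [POINTER(c_float)]
lib.llmodel_free_embedding.restype = None

def embed_text(model_handle, text):
    """Call llmodel_embedding and copy the result into a Python list."""
    size = c_size_t(0)
    raw = lib.llmodel_embedding(model_handle, text.encode("utf-8"), ctypes.byref(size))
    if not raw:
        return []
    values = [raw[i] for i in range(size.value)]
    lib.llmodel_free_embedding(raw)  # the caller owns the buffer, so release it here
    return values

Here model_handle stands for an already-loaded llmodel_model obtained through the existing model-creation and load entry points in llmodel_c.h.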
16 changes: 16 additions & 0 deletions gpt4all-backend/llmodel_shared.cpp
@@ -37,6 +37,13 @@ void LLModel::prompt(const std::string &prompt,
        return;
    }

    if (!supportsCompletion()) {
        std::string errorMessage = "ERROR: this model does not support text completion or chat!\n";
        responseCallback(-1, errorMessage);
        std::cerr << implementation().modelType() << errorMessage;
        return;
    }

    // tokenize the prompt
    std::vector<Token> embd_inp = tokenize(promptCtx, prompt);

@@ -158,3 +165,12 @@ void LLModel::prompt(const std::string &prompt,
        cachedTokens.clear();
    }
}

std::vector<float> LLModel::embedding(const std::string &/*text*/)
{
    if (!supportsEmbedding()) {
        std::string errorMessage = "ERROR: this model does not support generating embeddings!\n";
        std::cerr << implementation().modelType() << errorMessage;
    }
    return std::vector<float>();
}
2 changes: 2 additions & 0 deletions gpt4all-backend/mpt_impl.h
@@ -15,6 +15,8 @@ class MPT : public LLModel {
    MPT();
    ~MPT();

    bool supportsEmbedding() const override { return false; }
    bool supportsCompletion() const override { return true; }
    bool loadModel(const std::string &modelPath) override;
    bool isModelLoaded() const override;
    size_t requiredMem(const std::string &modelPath) override;
2 changes: 2 additions & 0 deletions gpt4all-backend/replit_impl.h
@@ -17,6 +17,8 @@ class Replit : public LLModel {
    Replit();
    ~Replit();

    bool supportsEmbedding() const override { return false; }
    bool supportsCompletion() const override { return true; }
    bool loadModel(const std::string &modelPath) override;
    bool isModelLoaded() const override;
    size_t requiredMem(const std::string & modelPath) override;
102 changes: 102 additions & 0 deletions gpt4all-backend/scripts/convert_bert_hf_to_ggml.py
@@ -0,0 +1,102 @@
import sys
import struct
import json
import torch
import numpy as np

from transformers import AutoModel, AutoTokenizer

if len(sys.argv) < 3:
    print("Usage: convert_bert_hf_to_ggml.py dir-model [ftype]\n")
    print(" ftype == 0 -> float32")
    print(" ftype == 1 -> float16")
    sys.exit(1)

# output in the same directory as the model
dir_model = sys.argv[1]
fname_out = sys.argv[1] + "/ggml-model.bin"

with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f:
    encoder = json.load(f)

with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
    hparams = json.load(f)

with open(dir_model + "/vocab.txt", "r", encoding="utf-8") as f:
    vocab = f.readlines()
# possible data types
# ftype == 0 -> float32
# ftype == 1 -> float16
#
# map from ftype to string
ftype_str = ["f32", "f16"]

ftype = 1
if len(sys.argv) > 2:
    ftype = int(sys.argv[2])
    if ftype < 0 or ftype > 1:
        print("Invalid ftype: " + str(ftype))
        sys.exit(1)
    fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"


tokenizer = AutoTokenizer.from_pretrained(dir_model)
model = AutoModel.from_pretrained(dir_model, low_cpu_mem_usage=True)
print (model)

print(tokenizer.encode('I believe the meaning of life is'))

list_vars = model.state_dict()
for name in list_vars.keys():
    print(name, list_vars[name].shape, list_vars[name].dtype)

fout = open(fname_out, "wb")

print(hparams)

fout.write(struct.pack("i", 0x62657274)) # magic: 'bert' in hex
fout.write(struct.pack("i", hparams["vocab_size"]))
fout.write(struct.pack("i", hparams["max_position_embeddings"]))
fout.write(struct.pack("i", hparams["hidden_size"]))
fout.write(struct.pack("i", hparams["intermediate_size"]))
fout.write(struct.pack("i", hparams["num_attention_heads"]))
fout.write(struct.pack("i", hparams["num_hidden_layers"]))
fout.write(struct.pack("i", ftype))

for i in range(hparams["vocab_size"]):
    text = vocab[i][:-1] # strips newline at the end
    #print(f"{i}:{text}")
    data = bytes(text, 'utf-8')
    fout.write(struct.pack("i", len(data)))
    fout.write(data)

for name in list_vars.keys():
    data = list_vars[name].squeeze().numpy()
    if name in ['embeddings.position_ids', 'pooler.dense.weight', 'pooler.dense.bias']:
        continue
    print("Processing variable: " + name + " with shape: ", data.shape)

    n_dims = len(data.shape)

    # ftype == 0 -> float32, ftype == 1 -> float16
    if ftype == 1 and name[-7:] == ".weight" and n_dims == 2:
        print(" Converting to float16")
        data = data.astype(np.float16)
        l_type = 1
    else:
        l_type = 0

    # header
    str = name.encode('utf-8')
    fout.write(struct.pack("iii", n_dims, len(str), l_type))
    for i in range(n_dims):
        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
    fout.write(str)

    # data
    data.tofile(fout)

fout.close()

print("Done. Output file: " + fname_out)
print("")
2 changes: 1 addition & 1 deletion gpt4all-bindings/python/gpt4all/__init__.py
@@ -1,2 +1,2 @@
from .gpt4all import GPT4All # noqa
from .gpt4all import GPT4All, embed # noqa
from .pyllmodel import LLModel # noqa
14 changes: 14 additions & 0 deletions gpt4all-bindings/python/gpt4all/gpt4all.py
@@ -15,6 +15,20 @@
# TODO: move to config
DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")

def embed(
    text: str
) -> list[float]:
    """
    Generate an embedding for the given text.

    Args:
        text: The text document to generate an embedding for.

    Returns:
        An embedding of your text document as a list of floats.
    """
    model = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin')
    return model.model.generate_embedding(text)

class GPT4All:
"""
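A hedged usage sketch of the new module-level helper; it assumes the ggml-all-MiniLM-L6-v2-f16.bin model file can be resolved through the usual GPT4All model lookup/download path.

from gpt4all import embed

# Embeds a short document with the bundled MiniLM BERT model; for
# all-MiniLM-L6-v2 the returned vector has 384 dimensions.
vector = embed("GPT4All can now produce text embeddings locally.")
print(len(vector), vector[:5])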
