Skip to content

Commit

Permalink
Move the implementation out of llmodel class.
Browse files Browse the repository at this point in the history
  • Loading branch information
manyoso committed Jul 13, 2023
1 parent 64b409e commit 33557b1
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 39 deletions.
22 changes: 11 additions & 11 deletions gpt4all-backend/llmodel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ static bool requires_avxonly() {
#endif
}

LLModel::Implementation::Implementation(Dlhandle &&dlhandle_) : dlhandle(new Dlhandle(std::move(dlhandle_))) {
LLImplementation::LLImplementation(Dlhandle &&dlhandle_) : dlhandle(new Dlhandle(std::move(dlhandle_))) {
auto get_model_type = dlhandle->get<const char *()>("get_model_type");
assert(get_model_type);
modelType = get_model_type();
Expand All @@ -54,7 +54,7 @@ LLModel::Implementation::Implementation(Dlhandle &&dlhandle_) : dlhandle(new Dlh
assert(construct_);
}

LLModel::Implementation::Implementation(Implementation &&o)
LLImplementation::LLImplementation(LLImplementation &&o)
: construct_(o.construct_)
, modelType(o.modelType)
, buildVariant(o.buildVariant)
Expand All @@ -63,19 +63,19 @@ LLModel::Implementation::Implementation(Implementation &&o)
o.dlhandle = nullptr;
}

LLModel::Implementation::~Implementation() {
// Releases the dynamic-library handle owned by this implementation.
// After a move, `dlhandle` is null (the moved-from object gave it up),
// and `delete` on a null pointer is a well-defined no-op, so no guard
// is needed.
LLImplementation::~LLImplementation() {
    delete dlhandle;
}

bool LLModel::Implementation::isImplementation(const Dlhandle &dl) {
// Returns true if the given dynamic library is a gpt4all backend
// implementation, determined by whether it exports the
// "is_g4a_backend_model_implementation" symbol.
// NOTE(review): only the symbol's presence is tested (the returned
// function pointer's truthiness) — the function is never invoked, so
// the uint32_t version argument implied by its signature is unused
// here; confirm that is intended.
bool LLImplementation::isImplementation(const Dlhandle &dl) {
    return dl.get<bool(uint32_t)>("is_g4a_backend_model_implementation");
}

const std::vector<LLModel::Implementation> &LLModel::implementationList() {
const std::vector<LLImplementation> &LLModel::implementationList() {
// NOTE: allocated on heap so we leak intentionally on exit so we have a chance to clean up the
// individual models without the cleanup of the static list interfering
static auto* libs = new std::vector<LLModel::Implementation>([] () {
std::vector<LLModel::Implementation> fres;
static auto* libs = new std::vector<LLImplementation>([] () {
std::vector<LLImplementation> fres;

auto search_in_directory = [&](const std::string& paths) {
std::stringstream ss(paths);
Expand All @@ -90,10 +90,10 @@ const std::vector<LLModel::Implementation> &LLModel::implementationList() {
// Add to list if model implementation
try {
Dlhandle dl(p.string());
if (!Implementation::isImplementation(dl)) {
if (!LLImplementation::isImplementation(dl)) {
continue;
}
fres.emplace_back(Implementation(std::move(dl)));
fres.emplace_back(LLImplementation(std::move(dl)));
} catch (...) {}
}
}
Expand All @@ -107,7 +107,7 @@ const std::vector<LLModel::Implementation> &LLModel::implementationList() {
return *libs;
}

const LLModel::Implementation* LLModel::implementation(std::ifstream& f, const std::string& buildVariant) {
const LLImplementation* LLModel::implementation(std::ifstream& f, const std::string& buildVariant) {
for (const auto& i : implementationList()) {
f.seekg(0);
if (!i.magicMatch(f)) continue;
Expand All @@ -126,7 +126,7 @@ LLModel *LLModel::construct(const std::string &modelPath, std::string buildVaria
std::ifstream f(modelPath, std::ios::binary);
if (!f) return nullptr;
// Get correct implementation
const LLModel::Implementation* impl = nullptr;
const LLImplementation* impl = nullptr;

#if defined(__APPLE__) && defined(__arm64__) // FIXME: See if metal works for intel macs
if (buildVariant == "auto") {
Expand Down
60 changes: 32 additions & 28 deletions gpt4all-backend/llmodel.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,11 @@
#define LLMODEL_MAX_PROMPT_BATCH 128

class Dlhandle;

class LLImplementation;
class LLModel {
public:
using Token = int32_t;

// Describes one dynamically loaded backend library: its model type,
// build variant, the magic-number matcher used to recognize model
// files, and the factory function used to create LLModel instances.
class Implementation {
    LLModel *(*construct_)();  // factory symbol resolved from the library

public:
    Implementation(Dlhandle&&);
    Implementation(const Implementation&) = delete;  // owns dlhandle; not copyable
    Implementation(Implementation&&);
    ~Implementation();

    // True if the library exports the backend-implementation marker symbol.
    static bool isImplementation(const Dlhandle&);

    std::string_view modelType, buildVariant;
    // Checks whether the model file's leading bytes match this backend.
    bool (*magicMatch)(std::ifstream& f);
    // NOTE(review): raw owning pointer, freed in the destructor — a
    // smart pointer would express the ownership; left as-is here.
    Dlhandle *dlhandle;

    // The only way an implementation should be constructed: build the
    // model via the library's factory and link it back to this
    // implementation descriptor.
    LLModel *construct() const {
        auto fres = construct_();
        fres->m_implementation = this;
        return fres;
    }
};

struct PromptContext {
std::vector<float> logits; // logits of current context
std::vector<int32_t> tokens; // current tokens in the context window
Expand Down Expand Up @@ -74,12 +51,12 @@ class LLModel {
virtual void setThreadCount(int32_t /*n_threads*/) {}
virtual int32_t threadCount() const { return 1; }

const Implementation& implementation() const {
const LLImplementation& implementation() const {
return *m_implementation;
}

static const std::vector<Implementation>& implementationList();
static const Implementation *implementation(std::ifstream& f, const std::string& buildVariant);
static const std::vector<LLImplementation>& implementationList();
static const LLImplementation *implementation(std::ifstream& f, const std::string& buildVariant);
static LLModel *construct(const std::string &modelPath, std::string buildVariant = "auto");

static void setImplementationsSearchPath(const std::string& path);
Expand All @@ -99,6 +76,33 @@ class LLModel {
// shared by all base classes so it isn't virtual
void recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate);

const Implementation *m_implementation = nullptr;
const LLImplementation *m_implementation = nullptr;

private:
friend class LLImplementation;
};

// Describes one dynamically loaded backend library: its model type,
// build variant, the magic-number matcher used to recognize model
// files, and the factory function used to create LLModel instances.
class LLImplementation {
    LLModel *(*construct_)();  // factory symbol resolved from the library

public:
    LLImplementation(Dlhandle&&);
    LLImplementation(const LLImplementation&) = delete;  // owns dlhandle; not copyable
    LLImplementation(LLImplementation&&);
    ~LLImplementation();

    // True if the library exports the backend-implementation marker symbol.
    static bool isImplementation(const Dlhandle&);

    std::string_view modelType, buildVariant;
    // Checks whether the model file's leading bytes match this backend.
    bool (*magicMatch)(std::ifstream& f);
    // NOTE(review): raw owning pointer, freed in the destructor — a
    // smart pointer would express the ownership; left as-is here.
    Dlhandle *dlhandle;

    // The only way an implementation should be constructed: build the
    // model via the library's factory and link it back to this
    // implementation descriptor.
    LLModel *construct() const {
        auto fres = construct_();
        fres->m_implementation = this;
        return fres;
    }
};

#endif // LLMODEL_H

0 comments on commit 33557b1

Please sign in to comment.