whisper : fix extra memory usage after recent processor changes
Had increased the memory buffer to the size of the model and forgot to
bring it down.
ggerganov committed Nov 2, 2022
1 parent c63ce24 commit 02dfd5b
Showing 1 changed file with 28 additions and 20 deletions: whisper.cpp
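For a sense of scale: before this commit, buf_memory was sized from MEM_REQ_MODEL rather than from a dedicated table, so a MODEL_LARGE context reserved 3260 MB for the key/value memory buffer where 306 MB is enough. A rough, self-contained sketch of the arithmetic, using only constants copied from the tables in this diff:

    // Rough illustration; the two constants come from the MEM_REQ_* tables below.
    #include <cstdio>

    int main() {
        const unsigned long long MB = 1024*1024;
        const unsigned long long before = 3260ull*MB; // old: buf_memory sized via MEM_REQ_MODEL (MODEL_LARGE)
        const unsigned long long after  =  306ull*MB; // new: MEM_REQ_MEMORY.at(MODEL_LARGE)
        printf("extra allocation removed: %llu MB\n", (before - after)/MB); // prints 2954 MB
        return 0;
    }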
@@ -133,11 +133,19 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
 static const size_t MB = 1024*1024;
 
 static const std::map<e_model, size_t> MEM_REQ_MODEL = {
-    { MODEL_TINY,     86ull*MB },
-    { MODEL_BASE,    165ull*MB },
-    { MODEL_SMALL,   540ull*MB },
-    { MODEL_MEDIUM, 1650ull*MB },
-    { MODEL_LARGE,  3260ull*MB },
+    { MODEL_TINY,     74ull*MB },
+    { MODEL_BASE,    142ull*MB },
+    { MODEL_SMALL,   466ull*MB },
+    { MODEL_MEDIUM, 1464ull*MB },
+    { MODEL_LARGE,  2952ull*MB },
 };
 
+static const std::map<e_model, size_t> MEM_REQ_MEMORY = {
+    { MODEL_TINY,     12ull*MB },
+    { MODEL_BASE,     24ull*MB },
+    { MODEL_SMALL,    70ull*MB },
+    { MODEL_MEDIUM,  184ull*MB },
+    { MODEL_LARGE,   306ull*MB },
+};
+
 static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
@@ -498,7 +506,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
 
     wctx.buf_model = new std::vector<uint8_t>();
     wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type));
-    wctx.buf_memory.resize(std::max(MEM_REQ_MODEL.at(model.type), MEM_REQ_MODEL.at(model.type))); // TODO: TMP !!!
+    wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
     wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
     wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));

@@ -722,20 +730,6 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
         }
     }
 
-    // create the ggml memory context
-    {
-        struct ggml_init_params params = {
-            .mem_size   = wctx.buf_memory.size(),
-            .mem_buffer = wctx.buf_memory.data(),
-        };
-
-        model.ctx_mem = ggml_init(params);
-        if (!model.ctx_mem) {
-            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
-            return false;
-        }
-    }
-
     // prepare memory for the weights
     {
         auto & ctx = model.ctx;
@@ -932,6 +926,20 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
         }
     }
 
+    // create the ggml memory context
+    {
+        struct ggml_init_params params = {
+            .mem_size   = wctx.buf_memory.size(),
+            .mem_buffer = wctx.buf_memory.data(),
+        };
+
+        model.ctx_mem = ggml_init(params);
+        if (!model.ctx_mem) {
+            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+            return false;
+        }
+    }
+
     // key + value memory
     {
         auto & ctx = model.ctx_mem;
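For context on the block being moved: ggml works with caller-owned memory, so whisper.cpp preallocates a plain byte buffer per purpose and hands its size and data pointer to ggml_init; the tensors created in that context then live inside the buffer. Below is a minimal sketch of the same pattern under that assumption, with a hypothetical arena/arena_init standing in for ggml's actual context API:

    #include <cstdint>
    #include <cstdio>
    #include <map>
    #include <vector>

    enum e_model { MODEL_TINY, MODEL_BASE };

    static const size_t MB = 1024*1024;

    // shaped like the MEM_REQ_MEMORY table in this diff
    static const std::map<e_model, size_t> MEM_REQ_MEMORY = {
        { MODEL_TINY, 12ull*MB },
        { MODEL_BASE, 24ull*MB },
    };

    // hypothetical stand-in for ggml_init_params / ggml_context:
    // the caller owns the buffer, the context only carves pieces out of it
    struct arena {
        uint8_t * base = nullptr;
        size_t    size = 0;
        size_t    used = 0;
    };

    static bool arena_init(arena & a, void * buf, size_t size) {
        if (buf == nullptr || size == 0) {
            return false;
        }
        a.base = static_cast<uint8_t *>(buf);
        a.size = size;
        a.used = 0;
        return true;
    }

    int main() {
        // same pattern as whisper_model_load: size the buffer per model type,
        // then hand (data, size) to the context initializer
        std::vector<uint8_t> buf_memory;
        buf_memory.resize(MEM_REQ_MEMORY.at(MODEL_TINY));

        arena ctx_mem;
        if (!arena_init(ctx_mem, buf_memory.data(), buf_memory.size())) {
            fprintf(stderr, "arena_init() failed\n");
            return 1;
        }

        printf("KV arena ready: %zu MB\n", ctx_mem.size/MB);
        return 0;
    }

The fix above follows directly from this pattern: since ggml_init never allocates beyond the buffer it is given, an oversized buf_memory is pure waste, and sizing it from the new MEM_REQ_MEMORY table reclaims the difference.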
