Skip to content

Commit

Permalink
ggml : fix thread-safety of ggml_init and ggml_free
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Oct 29, 2022
1 parent 85d6e1e commit a272f10
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,7 @@ struct ggml_state {

// global state
struct ggml_state g_state;
atomic_bool g_state_barrier = 0;

////////////////////////////////////////////////////////////////////////////////

Expand Down Expand Up @@ -1265,6 +1266,17 @@ int ggml_up64(int n) {
////////////////////////////////////////////////////////////////////////////////

struct ggml_context * ggml_init(struct ggml_init_params params) {
// make this function thread safe
{
int processing = atomic_fetch_add(&g_state_barrier, 1);
while (processing > 0) {
// wait for other threads to finish
atomic_fetch_sub(&g_state_barrier, 1);
sched_yield();
processing = atomic_fetch_add(&g_state_barrier, 1);
}
}

static bool is_first_call = true;
if (is_first_call) {
const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
Expand Down Expand Up @@ -1308,6 +1320,9 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {

if (ctx == NULL) {
GGML_PRINT_DEBUG("%s: no unused context found\n", __func__);

atomic_fetch_sub(&g_state_barrier, 1);

return NULL;
}

Expand All @@ -1322,10 +1337,25 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {

ggml_assert_aligned(ctx->mem_buffer);

GGML_PRINT_DEBUG("%s: context initialized\n", __func__);

atomic_fetch_sub(&g_state_barrier, 1);

return ctx;
}

void ggml_free(struct ggml_context * ctx) {
// make this function thread safe
{
int processing = atomic_fetch_add(&g_state_barrier, 1);
while (processing > 0) {
// wait for other threads to finish
atomic_fetch_sub(&g_state_barrier, 1);
sched_yield();
processing = atomic_fetch_add(&g_state_barrier, 1);
}
}

for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
if (&g_state.contexts[i].context == ctx) {
g_state.contexts[i].used = false;
Expand All @@ -1337,11 +1367,15 @@ void ggml_free(struct ggml_context * ctx) {
free(ctx->mem_buffer);
}

atomic_fetch_sub(&g_state_barrier, 1);

return;
}
}

GGML_PRINT_DEBUG("%s: context not found\n", __func__);

atomic_fetch_sub(&g_state_barrier, 1);
}

size_t ggml_used_mem(const struct ggml_context * ctx) {
Expand Down

0 comments on commit a272f10

Please sign in to comment.