diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 492f2d0af21..0e30c36fe2b 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -82,6 +82,8 @@ type Backend struct { meta *fs.GGML cpus, gpus []Context tensors map[string]*Context + + sched *C.struct_ggml_backend_sched } func New(r *os.File, params ml.BackendParams) (ml.Backend, error) { @@ -182,10 +184,24 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) { return nil, err } + backends := make([]*C.struct_ggml_backend, len(gpus)+len(cpus)) + bufts := make([]*C.struct_ggml_backend_buffer_type, len(gpus)+len(cpus)) + for i, c := range append(gpus, cpus...) { + backends[i] = c.backend + bufts[i] = C.ggml_backend_get_default_buffer_type(c.backend) + } + return &Backend{ meta: meta, cpus: cpus, gpus: gpus, + sched: C.ggml_backend_sched_new( + (*C.ggml_backend_t)(unsafe.Pointer(&backends[0])), + (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])), + C.int(len(backends)), + C.size_t(max(8192, len(meta.Tensors().Items())*5)), + true, + ), }, nil } @@ -219,31 +235,23 @@ func (b *Backend) NewContext() ml.Context { }) backends := make([]*C.struct_ggml_backend, len(b.gpus)+len(b.cpus)) - bufts := make([]*C.struct_ggml_backend_buffer_type, len(b.gpus)+len(b.cpus)) for i, c := range append(b.gpus, b.cpus...) { backends[i] = c.backend - bufts[i] = C.ggml_backend_get_default_buffer_type(c.backend) } return &Context{ + b: b, ctx: c, backend: backends[0], nodes: nodes, - sched: C.ggml_backend_sched_new( - (*C.ggml_backend_t)(unsafe.Pointer(&backends[0])), - (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])), - C.int(len(backends)), - C.size_t(nodes), - true, - ), } } type Context struct { + b *Backend ctx *C.struct_ggml_context backend *C.struct_ggml_backend - sched *C.struct_ggml_backend_sched graph *C.struct_ggml_cgraph nodes int } @@ -257,12 +265,13 @@ func (c *Context) Forward(t ml.Tensor) { } func (c *Context) Compute(tensors ...ml.Tensor) { - C.ggml_backend_sched_graph_compute_async(c.sched, c.graph) + C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph) + C.ggml_backend_sched_reset(c.b.sched) needSync := true sync := func() { if needSync { - C.ggml_backend_sched_synchronize(c.sched) + C.ggml_backend_sched_synchronize(c.b.sched) needSync = false } } @@ -350,7 +359,6 @@ func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { func (c *Context) Close() { if c != nil { - C.ggml_backend_sched_free(c.sched) C.ggml_free(c.ctx) } }