Skip to content

Commit

Permalink
Only set default keep_alive on initial model load
Browse files Browse the repository at this point in the history
This change fixes the handling of keep_alive so that if client
request omits the setting, we only set this on initial load.  Once
the model is loaded, if new requests leave this unset, we'll keep
whatever keep_alive was there.
  • Loading branch information
dhiltgen committed Jul 3, 2024
1 parent ccd7785 commit 955f2a4
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 71 deletions.
31 changes: 29 additions & 2 deletions envconfig/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ import (
"errors"
"fmt"
"log/slog"
"math"
"net"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
)

type OllamaHost struct {
Expand All @@ -34,7 +36,7 @@ var (
// Set via OLLAMA_HOST in the environment
Host *OllamaHost
// Set via OLLAMA_KEEP_ALIVE in the environment
KeepAlive string
KeepAlive time.Duration
// Set via OLLAMA_LLM_LIBRARY in the environment
LLMLibrary string
// Set via OLLAMA_MAX_LOADED_MODELS in the environment
Expand Down Expand Up @@ -132,6 +134,7 @@ func init() {
NumParallel = 0 // Autoselect
MaxRunners = 0 // Autoselect
MaxQueuedRequests = 512
KeepAlive = 5 * time.Minute

LoadConfig()
}
Expand Down Expand Up @@ -266,7 +269,10 @@ func LoadConfig() {
}
}

KeepAlive = clean("OLLAMA_KEEP_ALIVE")
ka := clean("OLLAMA_KEEP_ALIVE")
if ka != "" {
loadKeepAlive(ka)
}

var err error
ModelsDir, err = getModelsDir()
Expand Down Expand Up @@ -344,3 +350,24 @@ func getOllamaHost() (*OllamaHost, error) {
Port: port,
}, nil
}

func loadKeepAlive(ka string) {
v, err := strconv.Atoi(ka)
if err != nil {
d, err := time.ParseDuration(ka)
if err == nil {
if d < 0 {
KeepAlive = time.Duration(math.MaxInt64)
} else {
KeepAlive = d
}
}
} else {
d := time.Duration(v) * time.Second
if d < 0 {
KeepAlive = time.Duration(math.MaxInt64)
} else {
KeepAlive = d
}
}
}
17 changes: 17 additions & 0 deletions envconfig/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package envconfig

import (
"fmt"
"math"
"net"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -23,6 +25,21 @@ func TestConfig(t *testing.T) {
t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
LoadConfig()
require.True(t, FlashAttention)
t.Setenv("OLLAMA_KEEP_ALIVE", "")
LoadConfig()
require.Equal(t, 5*time.Minute, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "3")
LoadConfig()
require.Equal(t, 3*time.Second, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "1h")
LoadConfig()
require.Equal(t, 1*time.Hour, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "-1s")
LoadConfig()
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "-1")
LoadConfig()
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
}

func TestClientFromEnvironment(t *testing.T) {
Expand Down
57 changes: 3 additions & 54 deletions server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@ import (
"io"
"io/fs"
"log/slog"
"math"
"net"
"net/http"
"net/netip"
"os"
"os/signal"
"path/filepath"
"slices"
"strconv"
"strings"
"syscall"
"time"
Expand Down Expand Up @@ -56,8 +54,6 @@ func init() {
gin.SetMode(mode)
}

var defaultSessionDuration = 5 * time.Minute

func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
Expand Down Expand Up @@ -133,14 +129,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}

var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}

rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
var runner *runnerRef
select {
case runner = <-rCh:
Expand Down Expand Up @@ -320,32 +309,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
streamResponse(c, ch)
}

func getDefaultSessionDuration() time.Duration {
if envconfig.KeepAlive != "" {
v, err := strconv.Atoi(envconfig.KeepAlive)
if err != nil {
d, err := time.ParseDuration(envconfig.KeepAlive)
if err != nil {
return defaultSessionDuration
}

if d < 0 {
return time.Duration(math.MaxInt64)
}

return d
}

d := time.Duration(v) * time.Second
if d < 0 {
return time.Duration(math.MaxInt64)
}
return d
}

return defaultSessionDuration
}

func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest
err := c.ShouldBindJSON(&req)
Expand Down Expand Up @@ -380,14 +343,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}

var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}

rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
var runner *runnerRef
select {
case runner = <-rCh:
Expand Down Expand Up @@ -1318,14 +1274,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}

var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}

rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
var runner *runnerRef
select {
case runner = <-rCh:
Expand Down
14 changes: 10 additions & 4 deletions server/sched.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type LlmRequest struct {
model *Model
opts api.Options
origNumCtx int // Track the initial ctx request
sessionDuration time.Duration
sessionDuration *api.Duration
successCh chan *runnerRef
errCh chan error
schedAttempts uint
Expand Down Expand Up @@ -75,7 +75,7 @@ func InitScheduler(ctx context.Context) *Scheduler {
}

// context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
if opts.NumCtx < 4 {
opts.NumCtx = 4
}
Expand Down Expand Up @@ -389,7 +389,9 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
runner.expireTimer.Stop()
runner.expireTimer = nil
}
runner.sessionDuration = pending.sessionDuration
if pending.sessionDuration != nil {
runner.sessionDuration = pending.sessionDuration.Duration
}
pending.successCh <- runner
go func() {
<-pending.ctx.Done()
Expand All @@ -402,6 +404,10 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
if numParallel < 1 {
numParallel = 1
}
sessionDuration := envconfig.KeepAlive
if req.sessionDuration != nil {
sessionDuration = req.sessionDuration.Duration
}
llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
Expand All @@ -419,7 +425,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
modelPath: req.model.ModelPath,
llama: llama,
Options: &req.opts,
sessionDuration: req.sessionDuration,
sessionDuration: sessionDuration,
gpus: gpus,
estimatedVRAM: llama.EstimatedVRAM(),
estimatedTotal: llama.EstimatedTotal(),
Expand Down
22 changes: 11 additions & 11 deletions server/sched_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func TestLoad(t *testing.T) {
opts: api.DefaultOptions(),
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
sessionDuration: 2,
sessionDuration: &api.Duration{Duration: 2 * time.Second},
}
// Fail to load model first
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
Expand Down Expand Up @@ -142,7 +142,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
ctx: scenario.ctx,
model: model,
opts: api.DefaultOptions(),
sessionDuration: 5 * time.Millisecond,
sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
}
Expand All @@ -156,18 +156,18 @@ func TestRequests(t *testing.T) {

// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = 5 * time.Millisecond
scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
scenario1b.req.model = scenario1a.req.model
scenario1b.ggml = scenario1a.ggml
scenario1b.req.sessionDuration = 0
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}

// simple reload of same model
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
tmpModel := *scenario1a.req.model
scenario2a.req.model = &tmpModel
scenario2a.ggml = scenario1a.ggml
scenario2a.req.sessionDuration = 5 * time.Millisecond
scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}

// Multiple loaded models
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
Expand Down Expand Up @@ -318,11 +318,11 @@ func TestGetRunner(t *testing.T) {
defer done()

scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
scenario1a.req.sessionDuration = 0
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
scenario1b.req.sessionDuration = 0
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
scenario1c.req.sessionDuration = 0
scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
envconfig.MaxQueuedRequests = 1
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
Expand Down Expand Up @@ -402,7 +402,7 @@ func TestPrematureExpired(t *testing.T) {
case <-ctx.Done():
t.Fatal("timeout")
}
time.Sleep(scenario1a.req.sessionDuration)
time.Sleep(scenario1a.req.sessionDuration.Duration)
scenario1a.ctxDone()
time.Sleep(20 * time.Millisecond)
require.LessOrEqual(t, len(s.finishedReqCh), 1)
Expand All @@ -423,7 +423,7 @@ func TestUseLoadedRunner(t *testing.T) {
ctx: ctx,
opts: api.DefaultOptions(),
successCh: make(chan *runnerRef, 1),
sessionDuration: 2,
sessionDuration: &api.Duration{Duration: 2},
}
finished := make(chan *LlmRequest)
llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
Expand Down Expand Up @@ -614,7 +614,7 @@ func TestAlreadyCanceled(t *testing.T) {
dctx, done2 := context.WithCancel(ctx)
done2()
scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = 0
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
s := InitScheduler(ctx)
slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req
Expand Down

0 comments on commit 955f2a4

Please sign in to comment.