Skip to content

Commit

Permalink
prepend image tags (ollama#2789)
Browse files Browse the repository at this point in the history
instead of appending image tags, prepend them - this generally produces better results
  • Loading branch information
mxyng authored Feb 29, 2024
1 parent fa2f2b3 commit 0e19476
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 18 deletions.
8 changes: 5 additions & 3 deletions server/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,15 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
p = prompt{}
}

p.Prompt = msg.Content

var sb strings.Builder
for range msg.Images {
p.Prompt += fmt.Sprintf(" [img-%d]", imgId)
fmt.Fprintf(&sb, "[img-%d] ", imgId)
p.images = append(p.images, imgId)
imgId += 1
}

sb.WriteString(msg.Content)
p.Prompt = sb.String()
case "assistant":
if p.Response != "" {
prompts = append(prompts, p)
Expand Down
6 changes: 3 additions & 3 deletions server/prompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
},
window: 1024,
want: "You are a Wizard. Hello [img-0]",
want: "You are a Wizard. [img-0] Hello",
},
{
name: "images truncated",
Expand All @@ -165,7 +165,7 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
},
window: 1024,
want: "You are a Wizard. Hello [img-1]",
want: "You are a Wizard. [img-0] [img-1] Hello",
},
{
name: "empty list",
Expand Down Expand Up @@ -198,7 +198,7 @@ func TestChatPrompt(t *testing.T) {
}

if got != tc.want {
t.Errorf("got = %v, want %v", got, tc.want)
t.Errorf("got: %q, want: %q", got, tc.want)
}
})
}
Expand Down
25 changes: 13 additions & 12 deletions server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,19 @@ func GenerateHandler(c *gin.Context) {
slog.Debug("generate handler", "system", req.System)

var sb strings.Builder
for i := range req.Images {
fmt.Fprintf(&sb, "[img-%d] ", i)
}

sb.WriteString(req.Prompt)

p, err := Prompt(req.Template, req.System, sb.String(), "", true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

sb.Reset()
if req.Context != nil {
prev, err := loaded.runner.Decode(c.Request.Context(), req.Context)
if err != nil {
Expand All @@ -260,18 +273,6 @@ func GenerateHandler(c *gin.Context) {
sb.WriteString(prev)
}

// write image tags
// TODO: limit the number of images to fit in the context similar to the chat endpoint
for i := range req.Images {
req.Prompt += fmt.Sprintf(" [img-%d]", i)
}

p, err := Prompt(req.Template, req.System, req.Prompt, "", true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

sb.WriteString(p)

prompt = sb.String()
Expand Down

0 comments on commit 0e19476

Please sign in to comment.