Skip to content

Commit

Permalink
Merge pull request ollama#5653 from ollama/mxyng/collect-system
Browse files Browse the repository at this point in the history
template: preprocess message and collect system
  • Loading branch information
mxyng authored Jul 12, 2024
2 parents 3362733 + 36c87c4 commit e5c65a8
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 67 deletions.
37 changes: 15 additions & 22 deletions template/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,22 +102,8 @@ var response = parse.ActionNode{
},
}

// funcs is the template.FuncMap of helper functions callable from template text.
var funcs = template.FuncMap{
	// contents concatenates the Content of the given messages, separated by a
	// blank line. If a non-empty role is passed as the first optional argument,
	// only messages whose Role equals it are included; otherwise every message
	// is included.
	"contents": func(msgs []*api.Message, role ...string) string {
		// An absent or empty filter means "match everything".
		var want string
		if len(role) > 0 {
			want = role[0]
		}

		var selected []string
		for _, msg := range msgs {
			if want != "" && msg.Role != want {
				continue
			}
			selected = append(selected, msg.Content)
		}

		return strings.Join(selected, "\n\n")
	},
}

func Parse(s string) (*Template, error) {
tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
tmpl := template.New("").Option("missingkey=zero")

tmpl, err := tmpl.Parse(s)
if err != nil {
Expand Down Expand Up @@ -163,15 +149,16 @@ type Values struct {
}

func (t *Template) Execute(w io.Writer, v Values) error {
collated := collate(v.Messages)
system, collated := collate(v.Messages)
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{
"System": system,
"Messages": collated,
})
}

var b bytes.Buffer
var system, prompt, response string
var prompt, response string
for i, m := range collated {
switch m.Role {
case "system":
Expand Down Expand Up @@ -223,11 +210,13 @@ func (t *Template) Execute(w io.Writer, v Values) error {
}

// collate messages based on role. consecutive messages of the same role are merged
// into a single message. collate also pulls out and merges messages with Role == "system"
// which are templated separately. As a side effect, it mangles message content adding image
// tags ([img-%d]) as needed
func collate(msgs []api.Message) (collated []*api.Message) {
// into a single message. collate also collects and returns all system messages.
// collate mutates message content adding image tags ([img-%d]) as needed
func collate(msgs []api.Message) (string, []*api.Message) {
var n int

var system []string
var collated []*api.Message
for i := range msgs {
msg := msgs[i]
for range msg.Images {
Expand All @@ -240,14 +229,18 @@ func collate(msgs []api.Message) (collated []*api.Message) {
n++
}

if msg.Role == "system" {
system = append(system, msg.Content)
}

if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
collated[len(collated)-1].Content += "\n\n" + msg.Content
} else {
collated = append(collated, &msg)
}
}

return
return strings.Join(system, "\n\n"), collated
}

func parseNode(n parse.Node) []string {
Expand Down
53 changes: 8 additions & 45 deletions template/template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,11 @@ func TestExecuteWithMessages(t *testing.T) {
{"response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `{{- $system := contents .Messages "system" -}}
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
{{- $system = "" }}
{"messages", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
{{- end }}`},
},
Values{
Expand All @@ -243,13 +241,11 @@ func TestExecuteWithMessages(t *testing.T) {
{"response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `{{- $system := contents .Messages "system" -}}
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
{{- $system = "" }}
{"messages", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
{{- end }}`},
},
Values{
Expand Down Expand Up @@ -363,36 +359,3 @@ Answer: `,
})
}
}

func TestFuncs(t *testing.T) {
t.Run("contents", func(t *testing.T) {
cases := map[string]string{
"": "A\n\nB\n\nC\n\nD\n\nE\n\nF",
"system": "A\n\nF",
"user": "B\n\nE",
"assistant": "C\n\nD",
}

s := []*api.Message{
{Role: "system", Content: "A"},
{Role: "user", Content: "B"},
{Role: "assistant", Content: "C"},
{Role: "assistant", Content: "D"},
{Role: "user", Content: "E"},
{Role: "system", Content: "F"},
}

fn, ok := funcs["contents"].(func([]*api.Message, ...string) string)
if !ok {
t.Fatal("contents is not a function")
}

for k, v := range cases {
t.Run(k, func(t *testing.T) {
if diff := cmp.Diff(fn(s, k), v); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
})
}

0 comments on commit e5c65a8

Please sign in to comment.