Skip to content

Commit

Permalink
do no automatically aggregate system messages
Browse files Browse the repository at this point in the history
  • Loading branch information
mxyng committed Jul 11, 2024
1 parent 791650d commit e64f9eb
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 23 deletions.
39 changes: 20 additions & 19 deletions template/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,21 @@ var response = parse.ActionNode{
},
}

var funcs = template.FuncMap{
"aggregate": func(v []*api.Message, role string) string {
var aggregated []string
for _, m := range v {
if m.Role == role {
aggregated = append(aggregated, m.Content)
}
}

return strings.Join(aggregated, "\n\n")
},
}

func Parse(s string) (*Template, error) {
tmpl := template.New("").Option("missingkey=zero")
tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)

tmpl, err := tmpl.Parse(s)
if err != nil {
Expand Down Expand Up @@ -149,23 +162,21 @@ type Values struct {
}

func (t *Template) Execute(w io.Writer, v Values) error {
system, collated := collate(v.Messages)
collated := collate(v.Messages)
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{
"System": system,
"Messages": collated,
})
}

var b bytes.Buffer
var prompt, response string
var system, prompt, response string
for i, m := range collated {
switch m.Role {
case "system":
system = m.Content
case "user":
prompt = m.Content
if i != 0 {
system = ""
}
case "assistant":
response = m.Content
}
Expand All @@ -179,6 +190,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
return err
}

system = ""
prompt = ""
response = ""
}
Expand Down Expand Up @@ -209,25 +221,14 @@ func (t *Template) Execute(w io.Writer, v Values) error {
return err
}

type messages []*api.Message

// collate messages based on role. consecutive messages of the same role are merged
// into a single message. collate also pulls out and merges messages with Role == "system"
// which are templated separately. As a side effect, it mangles message content adding image
// tags ([img-%d]) as needed
func collate(msgs []api.Message) (system string, collated messages) {
func collate(msgs []api.Message) (collated []*api.Message) {
var n int
for i := range msgs {
msg := msgs[i]
if msg.Role == "system" {
if system != "" {
system += "\n\n"
}

system += msg.Content
continue
}

for range msg.Images {
imageTag := fmt.Sprintf("[img-%d]", n)
if !strings.Contains(msg.Content, "[img]") {
Expand Down
11 changes: 7 additions & 4 deletions template/template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ func TestTemplate(t *testing.T) {
})

t.Run("legacy", func(t *testing.T) {
t.Skip("legacy outputs are currently default outputs")
var legacy bytes.Buffer
if err := tmpl.Execute(&legacy, Values{Messages: tt, forceLegacy: true}); err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -154,11 +155,13 @@ func TestParse(t *testing.T) {
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
{`{{- range .Messages }}
{{- if eq .Role "system" }}SYSTEM:
{{- else if eq .Role "user" }}USER:
{{- else if eq .Role "assistant" }}ASSISTANT:
{{- end }} {{ .Content }}
{{- end }}`, []string{"content", "messages", "role"}},
{`{{- if .Messages }}
{{- if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}
{{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
Expand Down

0 comments on commit e64f9eb

Please sign in to comment.