Skip to content

Commit

Permalink
Memory now takes a context.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ric Szopa authored and Ric Szopa committed May 29, 2023
1 parent c85115e commit d005841
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions agent/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ import (
//
// A memory function should not drop any system messages, and should not drop
// the last user message. If this is impossible, it should return an error.
type Memory func(Config, []openai.ChatCompletionMessage) ([]openai.ChatCompletionMessage, error)
type Memory func(context.Context, Config, []openai.ChatCompletionMessage) ([]openai.ChatCompletionMessage, error)

// BufferMemory is a memory that keeps the last n messages. All the system
// messages will be kept. If the buffer size is too small, it will return an
// error.
func BufferMemory(n int) Memory {
return func(cfg Config, messages []openai.ChatCompletionMessage) ([]openai.ChatCompletionMessage, error) {
return func(ctx context.Context, cfg Config, messages []openai.ChatCompletionMessage) ([]openai.ChatCompletionMessage, error) {

// if there's less or equal to n messages, return all the messages
if len(messages) <= n {
Expand Down Expand Up @@ -207,7 +207,7 @@ func TokenBufferMemory(fillRatio float64) Memory {
panic("fillRatio must be in the range (0, 1]")
}

return func(cfg Config, messages []openai.ChatCompletionMessage) ([]openai.ChatCompletionMessage, error) {
return func(ctx context.Context, cfg Config, messages []openai.ChatCompletionMessage) ([]openai.ChatCompletionMessage, error) {
maxTokens := int(float64(cfg.MaxTokens) * fillRatio)

newMessages, droppedMessages, err := partitionByTokenLimit(cfg, messages, maxTokens, tokenCount)
Expand Down Expand Up @@ -273,7 +273,7 @@ func SummarizerMemoryWithTemplate(fillRatio float64, tmpl *template.Template, op
panic("fillRatio must be in the range (0, 1]")
}

return func(cfg Config, messages []openai.ChatCompletionMessage) ([]openai.ChatCompletionMessage, error) {
return func(ctx context.Context, cfg Config, messages []openai.ChatCompletionMessage) ([]openai.ChatCompletionMessage, error) {
maxTokens := int(float64(cfg.MaxTokens) * fillRatio)

retainedMessages, droppedMessages, err := partitionByTokenLimit(cfg, messages, maxTokens, tokenCount)
Expand Down Expand Up @@ -327,9 +327,7 @@ func SummarizerMemoryWithTemplate(fillRatio float64, tmpl *template.Template, op
return nil, err
}

// FIXME(ryszard): pass the context from outside.

summary, err := summarizer.Respond(context.TODO())
summary, err := summarizer.Respond(ctx)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit d005841

Please sign in to comment.