Skip to content

Commit

Permalink
Implement Inject.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ric Szopa authored and Ric Szopa committed Jun 6, 2023
1 parent a351c07 commit 2f53a7a
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 17 deletions.
18 changes: 18 additions & 0 deletions agent/agents.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ type Agent interface {
// call, but won't affect subsequent calls.
Respond(ctx context.Context, options ...Option) (message string, err error)

// Inject introduces a message into the ongoing conversation, giving the
// impression that the agent produced it. This method returns the message as
// processed by the agent, which will be identical to the input in simple
// cases. However, depending on the agent's behavior, the returned message
// may be different from the input.
Inject(message string, data ...any) (string, error)

// Messages returns all messages that the agent has sent and received.
Messages() []client.Message

Expand Down Expand Up @@ -118,6 +125,17 @@ func (ag *BaseAgent) Listen(message string, data ...any) (string, error) {
return message, nil
}

func (ag *BaseAgent) Inject(message string, data ...any) (string, error) {
if len(data) > 0 {
return "", errors.New("this agent does not support passing data to Inject")
}
ag.Append(client.Message{
Content: message,
Role: client.Assistant,
})
return message, nil
}

func (ag *BaseAgent) createRequest(options []Option) (Config, client.ChatCompletionRequest) {
cfg := ag.config.clone()
for _, opt := range options {
Expand Down
54 changes: 37 additions & 17 deletions agent/agents_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,43 @@ import (
)

func TestAgentSystem(t *testing.T) {
t.Run("System message is added to the actor's messages", func(t *testing.T) {
ag := &BaseAgent{}
systemMessage := "Test system message"

ag.System(systemMessage)

want := []client.Message{
{
Content: systemMessage,
Role: "system",
},
}

if !reflect.DeepEqual(ag.Messages(), want) {
t.Errorf("got %v, want %v", ag.Messages(), want)
}
})

ag := &BaseAgent{}
systemMessage := "Test system message"

ag.System(systemMessage)

want := []client.Message{
{
Content: systemMessage,
Role: client.System,
},
}

if !reflect.DeepEqual(ag.Messages(), want) {
t.Errorf("got %v, want %v", ag.Messages(), want)
}

}

func TestAgentInject(t *testing.T) {

ag := &BaseAgent{}
injected := "Fake agent message"

_, _ = ag.Inject(injected)

want := []client.Message{
{
Content: injected,
Role: client.Assistant,
},
}

if !reflect.DeepEqual(ag.Messages(), want) {
t.Errorf("got %v, want %v", ag.Messages(), want)
}

}

type MockClient struct {
Expand Down
9 changes: 9 additions & 0 deletions agent/templated.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,12 @@ func (ag *TemplatedAgent) System(templateName string, data ...any) (string, erro

return ag.Agent.System(p)
}

func (ag *TemplatedAgent) Inject(templateName string, data ...any) (string, error) {
p, err := ag.getPrompt(templateName, data...)
if err != nil {
return "", err
}

return ag.Agent.Inject(p)
}

0 comments on commit 2f53a7a

Please sign in to comment.