Skip to content

Commit

Permalink
ReAct framework.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ric Szopa authored and Ric Szopa committed May 29, 2023
1 parent 6af47cb commit c85115e
Show file tree
Hide file tree
Showing 6 changed files with 593 additions and 0 deletions.
104 changes: 104 additions & 0 deletions agent/react/agent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package react

import (
"bufio"
"context"
_ "embed"
"fmt"
"os"

"github.com/ryszard/agency/agent"
"github.com/ryszard/agency/util/python"
log "github.com/sirupsen/logrus"
)

//go:embed react_prompt.md
var systemPrompt string

func Work(ctx context.Context, client agent.Client, pythonPath string, cache agent.Cache, question string, options ...agent.Option) error {
ag := agent.New("pythonista", options...)

ag = agent.Cached(ag, cache)

err := ag.System(systemPrompt)
if err != nil {
return err
}

python, err := python.New(pythonPath)
if err != nil {
return err
}

defer python.Close()

steps := []Step{}

_, err = ag.Listen(fmt.Sprintf("Question: %s", question))
if err != nil {
return err
}

for {
msg, err := ag.Respond(context.Background())
if err != nil {
return err
}

log.WithField("msg", msg).Info("received message")

newSteps, err := Parse(msg)
if err != nil {
return err
}
log.WithField("newSteps", fmt.Sprintf("%+v", newSteps)).Info("parsed message")
//steps = append(steps, newSteps...)
for _, step := range newSteps {
fmt.Printf("%s\n", step)
}

steps = append(steps, newSteps...)
lastStep := steps[len(steps)-1]

if lastStep.Type == FinalAnswerStep {
return nil
} else if lastStep.Type != ActionStep {
_, err := ag.Listen("Please continue.")
if err != nil {
return err
}
continue
}
var observation string
switch lastStep.Argument {
case "python":
stdout, stderr, err := python.Execute(lastStep.Content)
if err != nil {
return err
}
fmt.Printf("stdout: %s\nstderr: %s\n", stdout, stderr)
observation = fmt.Sprintf("Observation: \nStandard Output: %s\nStandardError:\n%s\n", stdout, stderr)

case "human":
// Print the question
fmt.Printf("Question to Human: %s\n", question)
// Read the answer from standard input
reader := bufio.NewReader(os.Stdin)
fmt.Print("Answer: ")
answer, err := reader.ReadString('\n')
if err != nil {
return err
}
observation = fmt.Sprintf("Observation: Answer from human: %s\n", answer)

}

if _, err := ag.Listen(observation); err != nil {
return err
}

fmt.Println("\n" + observation + "\n")

}
return nil
}
134 changes: 134 additions & 0 deletions agent/react/conversation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package react

import (
"fmt"
"strings"
)

// StepType represents the type of step in the conversation.
type StepType string

const (
ThoughtStep StepType = "Thought"
ActionStep StepType = "Action"
ObservationStep StepType = "Observation"
QuestionStep StepType = "Question"
AssumptionStep StepType = "Assumption"
AnswerStep StepType = "Answer"
FinalAnswerStep StepType = "Final Answer"
Unrecognized StepType = ""
)

func (s StepType) Prefix() string {
return fmt.Sprintf("%s: ", s)
}

func (s StepType) IsRecognized() bool {
return s != Unrecognized
}

func matchStep(line string) StepType {
switch {
case strings.HasPrefix(line, ThoughtStep.Prefix()):
return ThoughtStep
case strings.HasPrefix(line, ActionStep.Prefix()):
return ActionStep
case strings.HasPrefix(line, AssumptionStep.Prefix()):
return AssumptionStep
case strings.HasPrefix(line, ObservationStep.Prefix()):
return ObservationStep
case strings.HasPrefix(line, QuestionStep.Prefix()):
return QuestionStep
case strings.HasPrefix(line, AnswerStep.Prefix()):
return AnswerStep
case strings.HasPrefix(line, FinalAnswerStep.Prefix()):
return FinalAnswerStep
default:
return Unrecognized
}
}

// Step represents an individual step in the conversation or process log, which includes
// a thought, action, action input, observation or final answer.
type Step struct {
Type StepType
Content string
Argument string
}

func (s Step) String() string {
if s.Type == ActionStep {
return fmt.Sprintf("%s: %s\n%s", s.Type, s.Argument, s.Content)
}
return fmt.Sprintf("%s: %s", s.Type, s.Content)
}

// Conversation represents a sequence of conversation steps
type Conversation struct {
Question string
Steps []Step
}

// NewConversation returns an empty conversation for a given question
func NewConversation(question string) *Conversation {
return &Conversation{Question: question}
}

// Parse parses a conversation from a string.
func Parse(text string) (steps []Step, err error) {
// Split the text into lines
lines := strings.Split(text, "\n")

currentStep := Step{}

for _, line := range lines {
stepType := matchStep(line)

if stepType.IsRecognized() {
// We found the beginning of a new step.
if currentStep.Type.IsRecognized() {
// There was a previous step in progress, so we should finalize
// it and add it to the list of steps.
steps = append(steps, currentStep)
}
currentStep = Step{Type: stepType}
if stepType != ActionStep {
currentStep.Content = strings.TrimSpace(strings.TrimPrefix(line, stepType.Prefix()))
} else {
// Split the line into the step type and the argument the first
// value returned by strings.Cut is going to be the step type.
// `found` is always going to be true, as otherwise the step
// type would have been Unrecognized
_, arg, _ := strings.Cut(line, ": ")

currentStep.Argument = strings.TrimSpace(arg)
}

} else if currentStep.Type.IsRecognized() {
// We are in the middle of a step; add the line to its content.
currentStep.Content += "\n" + line
} else {
// No new step, and no step in progress. Unless this is a blank
// we should return an error.
if strings.TrimSpace(line) != "" {
return nil, fmt.Errorf("unrecognized step: %q", line)
}
}
}

steps = append(steps, currentStep)

for i, step := range steps {

steps[i].Content = strings.TrimSpace(step.Content)

}

return steps, nil
}

// FinalAnswer is a method on Conversation that returns the final answer from the conversation steps
func (c *Conversation) FinalAnswer() (string, error) {
// implementation goes here
return "", nil
}
150 changes: 150 additions & 0 deletions agent/react/conversation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package react

import (
"strings"
"testing"
)

func TestAppendStepsFromText_Action(t *testing.T) {
//log.SetLevel(log.TraceLevel)

for _, tt := range []struct {
name string
text string
want []Step
}{
{
name: "basic",
text: `Question: Is the Python version used by the interpreter a stable release?
Assumption: I can use the Python interpreter
Thought: The version of the Python interpreter can be determined using the sys module in Python.
Action: python
import sys
sys.version`,
want: []Step{
{Type: QuestionStep, Content: "Is the Python version used by the interpreter a stable release?"},
{Type: AssumptionStep, Content: "I can use the Python interpreter"},
{Type: ThoughtStep, Content: "The version of the Python interpreter can be determined using the sys module in Python."},
{Type: ActionStep, Argument: "python", Content: "import sys\nsys.version"},
},
},
{
name: "line containing only tabs, empty lines",
text: `
Question: Is the Python version used by the interpreter a stable release?
Thought: The version of the Python interpreter can be determined using the sys module in Python.
Action: python
import sys
sys.version`,
want: []Step{
{Type: QuestionStep, Content: "Is the Python version used by the interpreter a stable release?"},
{Type: ThoughtStep, Content: "The version of the Python interpreter can be determined using the sys module in Python."},
{Type: ActionStep, Argument: "python", Content: "import sys\nsys.version"},
},
},
{
name: "thoughts, questions, actions",
text: `Thought: The Python interpreter is using version 3.8.5.
Question: Is Python version 3.8.5 a stable release?
Thought: Stable releases of Python usually have a version number with two parts (major.minor) or three parts (major.minor.micro) if the micro version is zero. If the micro version is greater than zero, it is usually a bug fix release which is also considered stable.
Action: python
version_parts = tuple(map(int, '3.8.5'.split('.')))
len(version_parts) in {2, 3} and (len(version_parts) != 3 or version_parts[2] == 0)`,
want: []Step{
{Type: ThoughtStep, Content: "The Python interpreter is using version 3.8.5."},
{Type: QuestionStep, Content: "Is Python version 3.8.5 a stable release?"},
{Type: ThoughtStep, Content: "Stable releases of Python usually have a version number with two parts (major.minor) or three parts (major.minor.micro) if the micro version is zero. If the micro version is greater than zero, it is usually a bug fix release which is also considered stable."},
{Type: ActionStep, Argument: "python", Content: "version_parts = tuple(map(int, '3.8.5'.split('.')))\nlen(version_parts) in {2, 3} and (len(version_parts) != 3 or version_parts[2] == 0)"},
}},
{
name: "python indentation",
text: `Question: Is the Python version used by the interpreter a stable release?
Thought: The version of the Python interpreter can be determined using the sys module in Python.
Action: python
def foo():
return 1`,

want: []Step{
{Type: QuestionStep, Content: "Is the Python version used by the interpreter a stable release?"},
{Type: ThoughtStep, Content: "The version of the Python interpreter can be determined using the sys module in Python."},
{Type: ActionStep, Argument: "python", Content: "def foo():\n return 1"},
},
},

{
name: "python don't lose empty lines",
text: `Question: Is the Python version used by the interpreter a stable release?
Thought: The version of the Python interpreter can be determined using the sys module in Python.
Action: python
def foo():
return 1
foo()`,

want: []Step{
{Type: QuestionStep, Content: "Is the Python version used by the interpreter a stable release?"},
{Type: ThoughtStep, Content: "The version of the Python interpreter can be determined using the sys module in Python."},
{Type: ActionStep, Argument: "python", Content: "def foo():\n return 1\n\nfoo()"},
},
},
{
name: "multiline question and thought",
text: `Question: Is the Python version used by the interpreter a stable release?
A lot depends on that.
Thought: The version of the Python interpreter can be determined using the sys module in Python.
Let's give it a try.
Action: python
import sys
sys.version`,
want: []Step{
{Type: QuestionStep, Content: "Is the Python version used by the interpreter a stable release?\n\nA lot depends on that."},
{Type: ThoughtStep, Content: "The version of the Python interpreter can be determined using the sys module in Python.\n\nLet's give it a try."},
{Type: ActionStep, Argument: "python", Content: "import sys\nsys.version"},
},
},
} {

steps, err := Parse(tt.text)
if err != nil {
t.Fatalf("%s: unexpected error: %v", tt.name, err)
}

for _, step := range steps {
step.Content = strings.TrimSpace(step.Content)
}

for i, step := range steps {
if step.Content != strings.TrimSpace(step.Content) {
t.Fatalf("%s: unexpected whitespace at step %d: %q", tt.name, i, step.Content)
}
}

if len(steps) != len(tt.want) {
t.Errorf("expected %d steps, got %d", len(tt.want), len(steps))
}

for i, step := range steps {
if step.Type != tt.want[i].Type {
t.Errorf("%s: unexpected step at index %d: type %s (want type %s)", tt.name, i, step.Type, tt.want[i].Type)
}

if step.Argument != tt.want[i].Argument {
t.Errorf("%s: unexpected step at index %d: argument %q (want argument %q)", tt.name, i, step.Argument, tt.want[i].Argument)
}

if step.Content != tt.want[i].Content {
t.Errorf("%s: unexpected step at index %d: content %q (want content %q)", tt.name, i, step.Content, tt.want[i].Content)
}

}
}
}
Loading

0 comments on commit c85115e

Please sign in to comment.