Skip to content

Commit

Permalink
Add proper streaming support.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ric Szopa authored and Ric Szopa committed Jun 4, 2023
1 parent 99345ed commit c7d549e
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 72 deletions.
62 changes: 5 additions & 57 deletions agent/agents.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,83 +130,31 @@ func (ag *BaseAgent) createRequest(options []Option) (Config, client.ChatComplet
}

// Respond sends the agent's accumulated messages to the configured client and
// returns the content of the model's reply. The reply is also appended to the
// agent's message history. If a Memory is configured it is given a chance to
// rewrite the history first; if streaming is configured, the request is
// delegated to respondStream.
func (ag *BaseAgent) Respond(ctx context.Context, options ...Option) (message string, err error) {
	logger := log.WithField("agent", ag.name)

	logger.Debug("Responding to message")
	cfg, req := ag.createRequest(options)

	if cfg.Memory != nil {
		logger.Debug("Using memory")
		newMessages, err := cfg.Memory(ctx, cfg, ag.messages)
		if err != nil {
			logger.WithError(err).Error("Failed to use memory")
			return "", err
		}
		ag.messages = newMessages
	}

	if cfg.Stream() {
		return ag.respondStream(ctx, options...)
	}

	logger.WithField("request", fmt.Sprintf("%+v", req)).Info("Sending request")
	resp, err := cfg.Client.CreateChatCompletion(ctx, req)
	// Log the raw response (and any error) before branching on err, so that
	// failures still leave a trace of what came back.
	logger.WithError(err).WithField("response", fmt.Sprintf("%+v", resp)).Debug("Received response from client")
	if err != nil {
		logger.WithError(err).Error("Failed to send request to the client")
		return "", err
	}
	logger.WithField("response", fmt.Sprintf("%+v", resp)).Info("Received response from client")

	// Guard against an empty Choices slice rather than panicking on [0].
	if len(resp.Choices) == 0 {
		return "", fmt.Errorf("client returned a response with no choices")
	}

	msg := resp.Choices[0]
	ag.Append(msg)

	return msg.Content, nil
}

// respondStream is the streaming counterpart of Respond. Streaming is now
// implemented inside the client (which writes content deltas to the request's
// Stream writer as they arrive), so this method currently performs no work
// and returns an empty message.
//
// FIXME(ryszard): either route streaming requests through the client's
// streaming path here or remove this method entirely — silently returning
// ("", nil) hides the fact that no request was made.
func (ag *BaseAgent) respondStream(ctx context.Context, options ...Option) (string, error) {
	return "", nil
}
13 changes: 2 additions & 11 deletions agent/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,10 @@ type Config struct {
// other fields
Client client.Client `json:"-"`

// Output is the writer to which the agent will write its output. If nil,
// the agent's output will be discarded. This is useful for streaming.
Output io.Writer `json:"-"`

// Memory is the agent's memory.
Memory Memory `json:"-"`
}

// Stream returns true if the agent is configured to stream its output.
func (cfg Config) Stream() bool {
	// Streaming is implied by the presence of an output writer.
	if cfg.Output == nil {
		return false
	}
	return true
}

func (ac Config) chatCompletionRequest() client.ChatCompletionRequest {

// FIXME(ryszard): handle Streaming.
Expand Down Expand Up @@ -101,14 +92,14 @@ var _ io.Writer = nullWriter{}
// Note that this will cause the agent to use streaming calls to the OpenAI API.
func WithStreaming(w io.Writer) Option {
	return func(ac *Config) {
		// Setting Stream on the request template tells the client to use
		// the streaming API and to write content deltas to w as they
		// arrive from the server.
		ac.RequestTemplate.Stream = w
	}
}

// WithoutStreaming suppresses streaming of the agent's responses.
func WithoutStreaming() Option {
	return func(ac *Config) {
		// A nil Stream writer means the client will use the regular,
		// non-streaming API.
		ac.RequestTemplate.Stream = nil
	}
}

Expand Down
11 changes: 8 additions & 3 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,17 @@ type ChatCompletionRequest struct {
// for the model.
CustomParams map[string]interface{} `json:"params"`

// Stream is a writer to which the API should write the response as it
// appears. The API will still return the response as a whole.
// If Stream is not nil, the client will use the streaming API. The client
// should write the message content from the server as it appears on the
// wire to Stream, and then still return the whole message.
Stream io.Writer `json:"-"` // This should not be used when hashing.
}

// Client is an interface for the OpenAI API client. It's main purpose is to
// WantsStreaming reports whether this request should be performed using the
// streaming API, i.e. whether a Stream writer has been provided.
func (r ChatCompletionRequest) WantsStreaming() bool {
	if r.Stream == nil {
		return false
	}
	return true
}

// Client is an interface for the LLM API client. Its main purpose is to
// make testing easier.
type Client interface {
CreateChatCompletion(ctx context.Context, req ChatCompletionRequest) (ChatCompletionResponse, error)
Expand Down
50 changes: 50 additions & 0 deletions client/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@ package openai

import (
"context"
"errors"
"fmt"
"io"
"strings"

"github.com/ryszard/agency/client"
"github.com/sashabaranov/go-openai"
log "github.com/sirupsen/logrus"
)

const (
Expand Down Expand Up @@ -44,13 +48,59 @@ func (cl *Client) CreateChatCompletion(ctx context.Context, request client.ChatC
if err != nil {
return client.ChatCompletionResponse{}, err
}
if request.WantsStreaming() {
return cl.createChatCompletionStream(ctx, req, request.Stream)
}
resp, err := cl.client.CreateChatCompletion(ctx, req)
if err != nil {
return client.ChatCompletionResponse{}, err
}
return TranslateResponse(resp), nil
}

// createChatCompletionStream performs req using the streaming API, writing
// each content delta to w as it arrives on the wire. It still assembles and
// returns the complete assistant message once the stream is exhausted.
func (cl *Client) createChatCompletionStream(ctx context.Context, req openai.ChatCompletionRequest, w io.Writer) (client.ChatCompletionResponse, error) {
	req.Stream = true

	log.WithFields(log.Fields{
		"request": fmt.Sprintf("%+v", req),
		"stream":  true,
	}).Info("RespondStream: Sending request")
	stream, err := cl.client.CreateChatCompletionStream(ctx, req)
	if err != nil {
		return client.ChatCompletionResponse{}, err
	}

	defer stream.Close()

	var b strings.Builder

	for {
		r, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			break
		} else if err != nil {
			return client.ChatCompletionResponse{}, err
		}
		// Guard against chunks with no choices rather than panicking on [0].
		if len(r.Choices) == 0 {
			continue
		}
		delta := r.Choices[0].Delta.Content
		// strings.Builder.WriteString is documented to never return an
		// error, so its result needs no check.
		b.WriteString(delta)
		if _, err := w.Write([]byte(delta)); err != nil {
			return client.ChatCompletionResponse{}, err
		}
	}
	// Terminate the streamed output with a blank line; report a write
	// failure instead of silently dropping it.
	if _, err := w.Write([]byte("\n\n")); err != nil {
		return client.ChatCompletionResponse{}, err
	}

	message := client.Message{
		Content: b.String(),
		Role:    client.Assistant,
	}

	return client.ChatCompletionResponse{Choices: []client.Message{message}}, nil
}

var roleMapping = map[client.Role]string{
client.User: openai.ChatMessageRoleUser,
client.System: openai.ChatMessageRoleSystem,
Expand Down
8 changes: 7 additions & 1 deletion client/retrying.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func retry(client *retryingClient, fn func() (any, error)) (any, error) {
if err == nil {
return resp, nil
}
log.WithError(err).Error("Error from OpenAI API")
log.WithError(err).Error("Error from the API")
wait := time.Duration(waitMultiplier) * client.baseWait
log.WithField("wait", wait).Info("Waiting before retrying")
if wait > client.maxWait {
Expand Down Expand Up @@ -43,6 +43,9 @@ func Retrying(client Client, baseWait time.Duration, maxWait time.Duration, maxR
if maxRetries == 0 {
panic("maxRetries must not be 0")
}
if client == nil {
panic("client must not be nil")
}
return &retryingClient{
client: client,
baseWait: baseWait,
Expand All @@ -54,7 +57,10 @@ func Retrying(client Client, baseWait time.Duration, maxWait time.Duration, maxR
func (client *retryingClient) CreateChatCompletion(ctx context.Context, req ChatCompletionRequest) (ChatCompletionResponse, error) {
log.WithField("req", req).Info("CreateChatCompletion")
resp, err := retry(client, func() (any, error) {
log.Trace("Calling client.client.CreateChatCompletion")
log.WithField("client", client.client).Trace("here")
return client.client.CreateChatCompletion(ctx, req)

})
if err != nil {
return ChatCompletionResponse{}, err
Expand Down

0 comments on commit c7d549e

Please sign in to comment.