Skip to content

Commit

Permalink
adding flag for pinning seed in openai and compatible APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
riccardo1980 committed Sep 20, 2024
1 parent f4044cd commit a619c91
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 8 deletions.
2 changes: 2 additions & 0 deletions cli/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ type Flags struct {
Language string `short:"g" long:"language" description:"Specify the Language Code for the chat, e.g. -g=en -g=zh" default:""`
ScrapeURL string `short:"u" long:"scrape_url" description:"Scrape website URL to markdown using Jina AI"`
ScrapeQuestion string `short:"q" long:"scrape_question" description:"Search question using Jina AI"`
Seed int `short:"e" long:"seed" description:"Seed to be used for LMM generation"`
}

// Init Initialize flags. returns a Flags struct and an error
Expand Down Expand Up @@ -99,6 +100,7 @@ func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) {
PresencePenalty: o.PresencePenalty,
FrequencyPenalty: o.FrequencyPenalty,
Raw: o.Raw,
Seed: o.Seed,
}
return
}
Expand Down
22 changes: 22 additions & 0 deletions cli/flags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ func TestBuildChatOptions(t *testing.T) {
TopP: 0.9,
PresencePenalty: 0.1,
FrequencyPenalty: 0.2,
Seed: 1,
}

expectedOptions := &common.ChatOptions{
Expand All @@ -61,6 +62,27 @@ func TestBuildChatOptions(t *testing.T) {
PresencePenalty: 0.1,
FrequencyPenalty: 0.2,
Raw: false,
Seed: 1,
}
options := flags.BuildChatOptions()
assert.Equal(t, expectedOptions, options)
}

func TestBuildChatOptionsDefaultSeed(t *testing.T) {
flags := &Flags{
Temperature: 0.8,
TopP: 0.9,
PresencePenalty: 0.1,
FrequencyPenalty: 0.2,
}

expectedOptions := &common.ChatOptions{
Temperature: 0.8,
TopP: 0.9,
PresencePenalty: 0.1,
FrequencyPenalty: 0.2,
Raw: false,
Seed: 0,
}
options := flags.BuildChatOptions()
assert.Equal(t, expectedOptions, options)
Expand Down
1 change: 1 addition & 0 deletions common/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type ChatOptions struct {
PresencePenalty float64
FrequencyPenalty float64
Raw bool
Seed int
}

// NormalizeMessages remove empty messages and ensure messages order user-assist-user
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ require (
github.com/samber/lo v1.47.0
github.com/sashabaranov/go-openai v1.30.0
github.com/stretchr/testify v1.9.0
golang.org/x/text v0.18.0
google.golang.org/api v0.197.0
)

Expand Down Expand Up @@ -61,7 +62,6 @@ require (
golang.org/x/oauth2 v0.23.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.25.0 // indirect
golang.org/x/text v0.18.0 // indirect
golang.org/x/time v0.6.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
Expand Down
28 changes: 21 additions & 7 deletions vendors/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"

"github.com/danielmiessler/fabric/common"
"github.com/samber/lo"
Expand Down Expand Up @@ -111,6 +112,7 @@ func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.
}
if len(resp.Choices) > 0 {
ret = resp.Choices[0].Message.Content
slog.Debug("SystemFingerprint: " + resp.SystemFingerprint)
}
return
}
Expand All @@ -128,13 +130,25 @@ func (o *Client) buildChatCompletionRequest(
Messages: messages,
}
} else {
ret = goopenai.ChatCompletionRequest{
Model: opts.Model,
Temperature: float32(opts.Temperature),
TopP: float32(opts.TopP),
PresencePenalty: float32(opts.PresencePenalty),
FrequencyPenalty: float32(opts.FrequencyPenalty),
Messages: messages,
if opts.Seed == 0 {
ret = goopenai.ChatCompletionRequest{
Model: opts.Model,
Temperature: float32(opts.Temperature),
TopP: float32(opts.TopP),
PresencePenalty: float32(opts.PresencePenalty),
FrequencyPenalty: float32(opts.FrequencyPenalty),
Messages: messages,
}
} else {
ret = goopenai.ChatCompletionRequest{
Model: opts.Model,
Temperature: float32(opts.Temperature),
TopP: float32(opts.TopP),
PresencePenalty: float32(opts.PresencePenalty),
FrequencyPenalty: float32(opts.FrequencyPenalty),
Messages: messages,
Seed: &opts.Seed,
}
}
}
return
Expand Down
102 changes: 102 additions & 0 deletions vendors/openai/openai_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package openai

import (
"testing"

"github.com/danielmiessler/fabric/common"
"github.com/sashabaranov/go-openai"
goopenai "github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert"
)

func TestBuildChatCompletionRequestPinSeed(t *testing.T) {

var msgs []*common.Message

for i := 0; i < 2; i++ {
msgs = append(msgs, &common.Message{
Role: "User",
Content: "My msg",
})
}

opts := &common.ChatOptions{
Temperature: 0.8,
TopP: 0.9,
PresencePenalty: 0.1,
FrequencyPenalty: 0.2,
Raw: false,
Seed: 1,
}

var expectedMessages []openai.ChatCompletionMessage

for i := 0; i < 2; i++ {
expectedMessages = append(expectedMessages,
openai.ChatCompletionMessage{
Role: msgs[i].Role,
Content: msgs[i].Content,
},
)
}

var expectedRequest = goopenai.ChatCompletionRequest{
Model: opts.Model,
Temperature: float32(opts.Temperature),
TopP: float32(opts.TopP),
PresencePenalty: float32(opts.PresencePenalty),
FrequencyPenalty: float32(opts.FrequencyPenalty),
Messages: expectedMessages,
Seed: &opts.Seed,
}

var client = NewClient()
request := client.buildChatCompletionRequest(msgs, opts)
assert.Equal(t, expectedRequest, request)
}

func TestBuildChatCompletionRequestNilSeed(t *testing.T) {

var msgs []*common.Message

for i := 0; i < 2; i++ {
msgs = append(msgs, &common.Message{
Role: "User",
Content: "My msg",
})
}

opts := &common.ChatOptions{
Temperature: 0.8,
TopP: 0.9,
PresencePenalty: 0.1,
FrequencyPenalty: 0.2,
Raw: false,
Seed: 0,
}

var expectedMessages []openai.ChatCompletionMessage

for i := 0; i < 2; i++ {
expectedMessages = append(expectedMessages,
openai.ChatCompletionMessage{
Role: msgs[i].Role,
Content: msgs[i].Content,
},
)
}

var expectedRequest = goopenai.ChatCompletionRequest{
Model: opts.Model,
Temperature: float32(opts.Temperature),
TopP: float32(opts.TopP),
PresencePenalty: float32(opts.PresencePenalty),
FrequencyPenalty: float32(opts.FrequencyPenalty),
Messages: expectedMessages,
Seed: nil,
}

var client = NewClient()
request := client.buildChatCompletionRequest(msgs, opts)
assert.Equal(t, expectedRequest, request)
}

0 comments on commit a619c91

Please sign in to comment.