Skip to content

Commit

Permalink
restruct: restruct channel factories
Browse files Browse the repository at this point in the history
  • Loading branch information
zmh-program committed Mar 12, 2024
1 parent 3335e81 commit 422da58
Show file tree
Hide file tree
Showing 52 changed files with 283 additions and 515 deletions.
185 changes: 24 additions & 161 deletions adapter/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,183 +4,46 @@ import (
"chat/adapter/azure"
"chat/adapter/baichuan"
"chat/adapter/bing"
"chat/adapter/chatgpt"
"chat/adapter/claude"
"chat/adapter/common"
"chat/adapter/dashscope"
"chat/adapter/hunyuan"
"chat/adapter/midjourney"
"chat/adapter/openai"
"chat/adapter/palm2"
"chat/adapter/skylark"
"chat/adapter/slack"
"chat/adapter/sparkdesk"
"chat/adapter/zhinao"
"chat/adapter/zhipuai"
"chat/globals"
"chat/utils"
"fmt"
)

type RequestProps struct {
MaxRetries *int
Current int
Group string
// channelFactories maps each channel type constant to the FactoryCreator
// that builds a chat instance for that provider. createChatRequest looks
// the channel's type up here, replacing the former per-type switch.
var channelFactories = map[string]adaptercommon.FactoryCreator{
globals.OpenAIChannelType: openai.NewChatInstanceFromConfig,
globals.AzureOpenAIChannelType: azure.NewChatInstanceFromConfig,
globals.ClaudeChannelType: claude.NewChatInstanceFromConfig,
globals.SlackChannelType: slack.NewChatInstanceFromConfig,
globals.BingChannelType: bing.NewChatInstanceFromConfig,
globals.PalmChannelType: palm2.NewChatInstanceFromConfig,
globals.SparkdeskChannelType: sparkdesk.NewChatInstanceFromConfig,
globals.ChatGLMChannelType: zhipuai.NewChatInstanceFromConfig,
globals.QwenChannelType: dashscope.NewChatInstanceFromConfig,
globals.HunyuanChannelType: hunyuan.NewChatInstanceFromConfig,
globals.BaichuanChannelType: baichuan.NewChatInstanceFromConfig,
globals.SkylarkChannelType: skylark.NewChatInstanceFromConfig,
globals.ZhinaoChannelType: zhinao.NewChatInstanceFromConfig,
globals.MidjourneyChannelType: midjourney.NewChatInstanceFromConfig,
}

type ChatProps struct {
RequestProps
func createChatRequest(conf globals.ChannelConfig, props *adaptercommon.ChatProps, hook globals.Hook) error {
props.Model = conf.GetModelReflect(props.OriginalModel)

Model string
Message []globals.Message
MaxTokens *int
PresencePenalty *float32
FrequencyPenalty *float32
RepetitionPenalty *float32
Temperature *float32
TopP *float32
TopK *int
Tools *globals.FunctionTools
ToolChoice *interface{}
Buffer utils.Buffer
}

func createChatRequest(conf globals.ChannelConfig, props *ChatProps, hook globals.Hook) error {
model := conf.GetModelReflect(props.Model)

switch conf.GetType() {
case globals.OpenAIChannelType:
return chatgpt.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&chatgpt.ChatProps{
Model: model,
Message: props.Message,
Token: props.MaxTokens,
PresencePenalty: props.PresencePenalty,
FrequencyPenalty: props.FrequencyPenalty,
Temperature: props.Temperature,
TopP: props.TopP,
Tools: props.Tools,
ToolChoice: props.ToolChoice,
Buffer: props.Buffer,
}, hook)

case globals.AzureOpenAIChannelType:
return azure.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&azure.ChatProps{
Model: model,
Message: props.Message,
Token: props.MaxTokens,
PresencePenalty: props.PresencePenalty,
FrequencyPenalty: props.FrequencyPenalty,
Temperature: props.Temperature,
TopP: props.TopP,
Tools: props.Tools,
ToolChoice: props.ToolChoice,
Buffer: props.Buffer,
}, hook)

case globals.ClaudeChannelType:
return claude.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&claude.ChatProps{
Model: model,
Message: props.Message,
Token: props.MaxTokens,
TopP: props.TopP,
TopK: props.TopK,
Temperature: props.Temperature,
}, hook)

case globals.SlackChannelType:
return slack.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&slack.ChatProps{
Message: props.Message,
}, hook)

case globals.BingChannelType:
return bing.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&bing.ChatProps{
Model: model,
Message: props.Message,
}, hook)

case globals.PalmChannelType:
return palm2.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&palm2.ChatProps{
Model: model,
Message: props.Message,
}, hook)

case globals.SparkdeskChannelType:
return sparkdesk.NewChatInstance(conf, model).CreateStreamChatRequest(&sparkdesk.ChatProps{
Model: model,
Message: props.Message,
Token: props.MaxTokens,
Temperature: props.Temperature,
TopK: props.TopK,
Tools: props.Tools,
Buffer: props.Buffer,
}, hook)

case globals.ChatGLMChannelType:
return zhipuai.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&zhipuai.ChatProps{
Model: model,
Message: props.Message,
Temperature: props.Temperature,
TopP: props.TopP,
}, hook)

case globals.QwenChannelType:
return dashscope.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&dashscope.ChatProps{
Model: model,
Message: props.Message,
Token: props.MaxTokens,
Temperature: props.Temperature,
TopP: props.TopP,
TopK: props.TopK,
RepetitionPenalty: props.RepetitionPenalty,
}, hook)

case globals.HunyuanChannelType:
return hunyuan.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&hunyuan.ChatProps{
Model: model,
Message: props.Message,
Temperature: props.Temperature,
TopP: props.TopP,
}, hook)

case globals.BaichuanChannelType:
return baichuan.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&baichuan.ChatProps{
Model: model,
Message: props.Message,
TopP: props.TopP,
TopK: props.TopK,
Temperature: props.Temperature,
}, hook)

case globals.SkylarkChannelType:
return skylark.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&skylark.ChatProps{
Model: model,
Message: props.Message,
Token: props.MaxTokens,
TopP: props.TopP,
TopK: props.TopK,
Temperature: props.Temperature,
FrequencyPenalty: props.FrequencyPenalty,
PresencePenalty: props.PresencePenalty,
RepeatPenalty: props.RepetitionPenalty,
Tools: props.Tools,
}, hook)

case globals.ZhinaoChannelType:
return zhinao.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&zhinao.ChatProps{
Model: model,
Message: props.Message,
Token: props.MaxTokens,
TopP: props.TopP,
TopK: props.TopK,
Temperature: props.Temperature,
RepetitionPenalty: props.RepetitionPenalty,
}, hook)

case globals.MidjourneyChannelType:
return midjourney.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&midjourney.ChatProps{
Model: model,
Messages: props.Message,
}, hook)

default:
return fmt.Errorf("unknown channel type %s (model: %s)", conf.GetType(), props.Model)
factoryType := conf.GetType()
if factory, ok := channelFactories[factoryType]; ok {
return factory(conf).CreateStreamChatRequest(props, hook)
}

return fmt.Errorf("unknown channel type %s (channel #%d)", conf.GetType(), conf.GetId())
}
38 changes: 13 additions & 25 deletions adapter/azure/chat.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,15 @@
package azure

import (
adaptercommon "chat/adapter/common"
"chat/globals"
"chat/utils"
"errors"
"fmt"
"strings"
)

type ChatProps struct {
Model string
Message []globals.Message
Token *int
PresencePenalty *float32
FrequencyPenalty *float32
Temperature *float32
TopP *float32
Tools *globals.FunctionTools
ToolChoice *interface{}
Buffer utils.Buffer
}

func (c *ChatInstance) GetChatEndpoint(props *ChatProps) string {
func (c *ChatInstance) GetChatEndpoint(props *adaptercommon.ChatProps) string {
model := strings.ReplaceAll(props.Model, ".", "")
if props.Model == globals.GPT3TurboInstruct {
return fmt.Sprintf("%s/openai/deployments/%s/completions?api-version=%s", c.GetResource(), model, c.GetEndpoint())
Expand All @@ -37,27 +25,27 @@ func (c *ChatInstance) GetCompletionPrompt(messages []globals.Message) string {
return result
}

func (c *ChatInstance) GetLatestPrompt(props *ChatProps) string {
// GetLatestPrompt returns the content of the last message in props.Message,
// or the empty string when there are no messages. It supplies the prompt
// text for image generation (see CreateImage in image.go).
func (c *ChatInstance) GetLatestPrompt(props *adaptercommon.ChatProps) string {
if len(props.Message) == 0 {
return ""
}

return props.Message[len(props.Message)-1].Content
}

func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} {
func (c *ChatInstance) GetChatBody(props *adaptercommon.ChatProps, stream bool) interface{} {
if props.Model == globals.GPT3TurboInstruct {
// for completions
return CompletionRequest{
Prompt: c.GetCompletionPrompt(props.Message),
MaxToken: props.Token,
MaxToken: props.MaxTokens,
Stream: stream,
}
}

return ChatRequest{
Messages: formatMessages(props),
MaxToken: props.Token,
MaxToken: props.MaxTokens,
Stream: stream,
PresencePenalty: props.PresencePenalty,
FrequencyPenalty: props.FrequencyPenalty,
Expand All @@ -68,8 +56,8 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} {
}
}

// CreateChatRequest is the native http request body for chatgpt
func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
// CreateChatRequest is the native http request body for openai
func (c *ChatInstance) CreateChatRequest(props *adaptercommon.ChatProps) (string, error) {
if globals.IsOpenAIDalleModel(props.Model) {
return c.CreateImage(props)
}
Expand All @@ -81,20 +69,20 @@ func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
)

if err != nil || res == nil {
return "", fmt.Errorf("chatgpt error: %s", err.Error())
return "", fmt.Errorf("openai error: %s", err.Error())
}

data := utils.MapToStruct[ChatResponse](res)
if data == nil {
return "", fmt.Errorf("chatgpt error: cannot parse response")
return "", fmt.Errorf("openai error: cannot parse response")
} else if data.Error.Message != "" {
return "", fmt.Errorf("chatgpt error: %s", data.Error.Message)
return "", fmt.Errorf("openai error: %s", data.Error.Message)
}
return data.Choices[0].Message.Content, nil
}

// CreateStreamChatRequest is the stream response body for chatgpt
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
// CreateStreamChatRequest is the stream response body for openai
func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, callback globals.Hook) error {
if globals.IsOpenAIDalleModel(props.Model) {
if url, err := c.CreateImage(props); err != nil {
return err
Expand Down
9 changes: 5 additions & 4 deletions adapter/azure/image.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package azure

import (
adaptercommon "chat/adapter/common"
"chat/globals"
"chat/utils"
"fmt"
Expand Down Expand Up @@ -32,21 +33,21 @@ func (c *ChatInstance) CreateImageRequest(props ImageProps) (string, error) {
N: 1,
})
if err != nil || res == nil {
return "", fmt.Errorf("chatgpt error: %s", err.Error())
return "", fmt.Errorf("openai error: %s", err.Error())
}

data := utils.MapToStruct[ImageResponse](res)
if data == nil {
return "", fmt.Errorf("chatgpt error: cannot parse response")
return "", fmt.Errorf("openai error: cannot parse response")
} else if data.Error.Message != "" {
return "", fmt.Errorf("chatgpt error: %s", data.Error.Message)
return "", fmt.Errorf("openai error: %s", data.Error.Message)
}

return data.Data[0].Url, nil
}

// CreateImage will create a dalle image from prompt, return markdown of image
func (c *ChatInstance) CreateImage(props *ChatProps) (string, error) {
func (c *ChatInstance) CreateImage(props *adaptercommon.ChatProps) (string, error) {
url, err := c.CreateImageRequest(ImageProps{
Model: props.Model,
Prompt: c.GetLatestPrompt(props),
Expand Down
9 changes: 5 additions & 4 deletions adapter/azure/processor.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package azure

import (
adaptercommon "chat/adapter/common"
"chat/globals"
"chat/utils"
"errors"
"fmt"
"regexp"
)

func formatMessages(props *ChatProps) interface{} {
func formatMessages(props *adaptercommon.ChatProps) interface{} {
if globals.IsVisionModel(props.Model) {
return utils.Each[globals.Message, Message](props.Message, func(message globals.Message) Message {
if message.Role == globals.User {
Expand Down Expand Up @@ -120,7 +121,7 @@ func (c *ChatInstance) ProcessLine(data string, isCompletionType bool) (*globals
}, nil
}

globals.Warn(fmt.Sprintf("chatgpt error: cannot parse completion response: %s", data))
globals.Warn(fmt.Sprintf("openai error: cannot parse completion response: %s", data))
return &globals.Chunk{Content: ""}, errors.New("parser error: cannot parse completion response")
}

Expand All @@ -129,9 +130,9 @@ func (c *ChatInstance) ProcessLine(data string, isCompletionType bool) (*globals
}

if form := processChatErrorResponse(data); form != nil {
return &globals.Chunk{Content: ""}, errors.New(fmt.Sprintf("chatgpt error: %s (type: %s)", form.Error.Message, form.Error.Type))
return &globals.Chunk{Content: ""}, errors.New(fmt.Sprintf("openai error: %s (type: %s)", form.Error.Message, form.Error.Type))
}

globals.Warn(fmt.Sprintf("chatgpt error: cannot parse chat completion response: %s", data))
globals.Warn(fmt.Sprintf("openai error: cannot parse chat completion response: %s", data))
return &globals.Chunk{Content: ""}, errors.New("parser error: cannot parse chat completion response")
}
3 changes: 2 additions & 1 deletion adapter/azure/struct.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package azure

import (
factory "chat/adapter/common"
"chat/globals"
)

Expand Down Expand Up @@ -42,7 +43,7 @@ func NewChatInstance(endpoint, apiKey string, resource string) *ChatInstance {
}
}

func NewChatInstanceFromConfig(conf globals.ChannelConfig) *ChatInstance {
func NewChatInstanceFromConfig(conf globals.ChannelConfig) factory.Factory {
param := conf.SplitRandomSecret(2)
return NewChatInstance(
conf.GetEndpoint(),
Expand Down
Loading

0 comments on commit 422da58

Please sign in to comment.