Skip to content

Commit

Permalink
feat: add gpt support by directly request url
Browse files Browse the repository at this point in the history
  • Loading branch information
Leizhenpeng committed Feb 7, 2023
1 parent f4a0bc3 commit 09323dc
Show file tree
Hide file tree
Showing 9 changed files with 198 additions and 52 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ go.work
.vscode
.s

./code/feishu_config.yaml
config.yaml



5 changes: 5 additions & 0 deletions code/config.example.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
APP_ID: cli_axxx
APP_SECRET: xxx
APP_ENCRYPT_KEY: xxxx
APP_VERIFICATION_TOKEN: xxx
OPENAI_KEY: XXX
10 changes: 6 additions & 4 deletions code/handlers/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ func sendMsg(ctx context.Context, msg string, chatId *string) {
msg = strings.Trim(msg, "\n")
msg = strings.Trim(msg, "\r")
msg = strings.Trim(msg, "\t")
//只保留中文和英文
regex := regexp.MustCompile("i[^a-zA-Z0-9\u4e00-\u9fa5]")
msg = regex.ReplaceAllString(msg, "")
// 去除空行 以及空行前的空格
regex := regexp.MustCompile(`\n[\s| ]*\r`)
msg = regex.ReplaceAllString(msg, "\n")
//换行符转义
msg = strings.ReplaceAll(msg, "\n", "\\n")
fmt.Println("sendMsg", msg, chatId)
client := initialization.GetLarkClient()
content := larkim.NewTextMsgBuilder().
Text(msg).
TextLine(msg).
Build()
fmt.Println("content", content)

Expand Down
48 changes: 29 additions & 19 deletions code/handlers/personal.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,55 @@ package handlers
import (
"context"
"fmt"
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
"start-feishubot/services"
)

type PersonalMessageHandler struct {
cache services.UserCacheInterface
userCache services.UserCacheInterface
msgCache services.MsgCacheInterface
}

func (p PersonalMessageHandler) handle(ctx context.Context, event *larkim.P2MessageReceiveV1) error {

fmt.Println(larkcore.Prettify(event))
content := event.Event.Message.Content
q := parseContent(*content)
fmt.Println("q", q)
//sender := event.Event.Sender
//openId := sender.SenderId.OpenId
//cacheContent := p.cache.Get(*openId)
qEnd := q
//if cacheContent != "" {
// qEnd = cacheContent + q
//}
msgId := event.Event.Message.MessageId
if p.msgCache.IfProcessed(*msgId) {
fmt.Println("msgId", *msgId, "processed")
return nil
}
qParsed := parseContent(*content)
fmt.Println("qParsed", qParsed)
sender := event.Event.Sender
openId := sender.SenderId.OpenId
cacheContent := p.userCache.Get(*openId)
qEnd := qParsed
if cacheContent != "" {
qEnd = cacheContent + qParsed
}
fmt.Println("qEnd", qEnd)
ok := true
reply, ok := services.GetAnswer(qEnd)
fmt.Println("reply", reply, ok)
completions, err := services.Completions(qEnd)
p.msgCache.TagProcessed(*msgId)
if err != nil {
return err
}
if len(completions) == 0 {
ok = false
}
if ok {
sendMsg(ctx, reply, event.Event.Message.ChatId)
//p.cache.Set(*openId, q, "nihao")
return nil
p.userCache.Set(*openId, qParsed, completions)
sendMsg(ctx, completions, event.Event.Message.ChatId)
}

return nil

}

var _ MessageHandlerInterface = (*PersonalMessageHandler)(nil)

func NewPersonalMessageHandler() MessageHandlerInterface {
return &PersonalMessageHandler{
cache: services.GetUserCache(),
userCache: services.GetUserCache(),
msgCache: services.GetMsgCache(),
}
}
2 changes: 1 addition & 1 deletion code/initialization/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
)

func LoadConfig() {
viper.SetConfigFile("./feishu_config.yaml")
viper.SetConfigFile("./config.yaml")
err := viper.ReadInConfig()
if err != nil {
panic(fmt.Errorf("Fatal error config file: %s \n", err))
Expand Down
105 changes: 84 additions & 21 deletions code/services/gpt3.go
Original file line number Diff line number Diff line change
@@ -1,39 +1,102 @@
package services

import (
"context"
"github.com/PullRequestInc/go-gpt3"
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/spf13/viper"
"io/ioutil"
"log"
"net/http"
)

const (
BASEURL = "https://api.openai.com/v1/"
maxTokens = 2000
temperature = 0.7
engine = gpt3.TextDavinci003Engine
engine = "text-davinci-003"
)

func GetAnswer(question string) (reply string, ok bool) {
client := gpt3.NewClient(viper.GetString("OPENAI_KEY"))

ok = false
reply = ""
ctx := context.Background()
resp, err := client.CompletionWithEngine(ctx, engine, gpt3.CompletionRequest{
Prompt: []string{
question,
},
MaxTokens: gpt3.IntPtr(maxTokens),
Temperature: gpt3.Float32Ptr(temperature),
})
// ChatGPTResponseBody 请求体
type ChatGPTResponseBody struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
Model string `json:"model"`
Choices []ChoiceItem `json:"choices"`
Usage map[string]interface{} `json:"usage"`
}

type ChoiceItem struct {
Text string `json:"text"`
Index int `json:"index"`
Logprobs int `json:"logprobs"`
FinishReason string `json:"finish_reason"`
}

// ChatGPTRequestBody 响应体
type ChatGPTRequestBody struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
MaxTokens int `json:"max_tokens"`
Temperature float32 `json:"temperature"`
TopP int `json:"top_p"`
FrequencyPenalty int `json:"frequency_penalty"`
PresencePenalty int `json:"presence_penalty"`
}

func Completions(msg string) (string, error) {
requestBody := ChatGPTRequestBody{
Model: engine,
Prompt: msg,
MaxTokens: maxTokens,
Temperature: temperature,
TopP: 1,
FrequencyPenalty: 0,
PresencePenalty: 0,
}
requestData, err := json.Marshal(requestBody)

if err != nil {
return "", err
}
log.Printf("request gtp json string : %v", string(requestData))
req, err := http.NewRequest("POST", BASEURL+"completions", bytes.NewBuffer(requestData))
if err != nil {
log.Fatalln(err)
return "", err
}
reply = resp.Choices[0].Text
if reply != "" {
ok = true

apiKey := viper.GetString("OPENAI_KEY")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
client := &http.Client{}
response, err := client.Do(req)
if err != nil {
return "", err
}
defer response.Body.Close()
if response.StatusCode != 200 {
return "", errors.New(fmt.Sprintf("gtp api status code not equals 200,code is %d", response.StatusCode))
}
body, err := ioutil.ReadAll(response.Body)
if err != nil {
return "", err
}

gptResponseBody := &ChatGPTResponseBody{}
log.Println(string(body))
err = json.Unmarshal(body, gptResponseBody)
if err != nil {
return "", err
}

var reply string
if len(gptResponseBody.Choices) > 0 {
reply = gptResponseBody.Choices[0].Text
}
return reply, ok
log.Printf("gpt response text: %s \n", reply)
return reply, nil
}

func FormatQuestion(question string) string {
Expand Down
40 changes: 40 additions & 0 deletions code/services/msgCache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package services

import (
"github.com/patrickmn/go-cache"
"time"
)

type MsgService struct {
cache *cache.Cache
}

var msgService *MsgService

func (u MsgService) IfProcessed(msgId string) bool {
get, b := u.cache.Get(msgId)
if !b {
return false
}
return get.(bool)
}
func (u MsgService) TagProcessed(msgId string) {
u.cache.Set(msgId, true, time.Minute*5)
}

func (u MsgService) Clear(userId string) bool {
u.cache.Delete(userId)
return true
}

type MsgCacheInterface interface {
IfProcessed(msg string) bool
TagProcessed(msg string)
}

func GetMsgCache() MsgCacheInterface {
if msgService == nil {
msgService = &MsgService{cache: cache.New(10*time.Minute, 10*time.Minute)}
}
return msgService
}
30 changes: 26 additions & 4 deletions code/services/userCache.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,38 @@ type UserService struct {
var userServices *UserService

func (u UserService) Get(userId string) string {
get, b := u.cache.Get(userId)
if !b {
// 获取用户的会话上下文
sessionContext, ok := u.cache.Get(userId)
if !ok {
return ""
}
return get.(string)
//list to string
list := sessionContext.([]string)
var result string
for _, v := range list {
result += v + "------------------------\n"
}
return result
}

func (u UserService) Set(userId string, question, reply string) {
// 列表,最多保存4个
//如果满了,删除最早的一个
//如果没有满,直接添加
listOut := make([]string, 4)
value := question + "\n" + reply
u.cache.Set(userId, value, time.Minute*5)

raw, ok := u.cache.Get(userId)
if ok {
listOut = raw.([]string)
if len(listOut) == 4 {
listOut = listOut[1:]
}
listOut = append(listOut, value)
} else {
listOut = append(listOut, value)
}
u.cache.Set(userId, listOut, time.Minute*5)
}

func (u UserService) Clear(userId string) bool {
Expand Down
7 changes: 5 additions & 2 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@

### 相关阅读

- [go-value-cache]()

- [在Go语言项目中使用Zap日志库](https://www.liwenzhou.com/posts/Go/zap/)
-

- [飞书 User_ID、Open_ID 与 Union_ID 区别](https://www.feishu.cn/hc/zh-CN/articles/794300086214)
-

- [飞书重复接受到消息](https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/im-v1/message/events/receive)

0 comments on commit 09323dc

Please sign in to comment.