Skip to content

Commit

Permalink
fix: support ali's embedding model (songquanpeng#481, close songquanpeng#469)
Browse files Browse the repository at this point in the history

* feat:支持阿里的 embedding 模型

* fix: add to model list

---------

Co-authored-by: JustSong <[email protected]>
Co-authored-by: JustSong <[email protected]>
  • Loading branch information
3 people authored Sep 3, 2023
1 parent bd6fe1e commit d0a0e87
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 20 deletions.
7 changes: 4 additions & 3 deletions common/model-ratio.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ var ModelRatio = map[string]float64{
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"qwen-v1": 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag
"qwen-plus-v1": 0.5715, // Same as above
"SparkDesk": 0.8572, // TBD
"qwen-v1": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus-v1": 1, // ¥0.014 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
Expand Down
9 changes: 9 additions & 0 deletions controller/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,15 @@ func init() {
Root: "qwen-plus-v1",
Parent: nil,
},
{
Id: "text-embedding-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "text-embedding-v1",
Parent: nil,
},
{
Id: "SparkDesk",
Object: "model",
Expand Down
88 changes: 88 additions & 0 deletions controller/relay-ali.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,29 @@ type AliChatRequest struct {
Parameters AliParameters `json:"parameters,omitempty"`
}

// AliEmbeddingRequest is the JSON request body produced by
// embeddingRequestOpenAI2Ali for Ali embedding requests.
type AliEmbeddingRequest struct {
// Model is serialized under the "model" key.
Model string `json:"model"`
// Input wraps the list of texts to embed under "input" -> "texts".
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
// Parameters is optional: it is a pointer, so a nil value is omitted from
// the encoded JSON by omitempty.
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}

// AliEmbedding is a single element of the "embeddings" array in an
// AliEmbeddingResponse.
type AliEmbedding struct {
// Embedding is the vector decoded from the "embedding" key.
Embedding []float64 `json:"embedding"`
// TextIndex is decoded from "text_index"; embeddingResponseAli2OpenAI copies
// it into the OpenAI item's Index. Presumably the position of the source
// text in the request — TODO confirm against the upstream API.
TextIndex int `json:"text_index"`
}

// AliEmbeddingResponse is the JSON body that aliEmbeddingHandler decodes from
// the upstream embedding response.
type AliEmbeddingResponse struct {
// Output holds the returned embeddings under "output" -> "embeddings".
Output struct {
Embeddings []AliEmbedding `json:"embeddings"`
} `json:"output"`
// Usage carries the token counts reported under "usage".
Usage AliUsage `json:"usage"`
// AliError is embedded, so its fields are decoded from the same top-level
// JSON object; aliEmbeddingHandler treats a non-empty Code as a failure.
AliError
}

type AliError struct {
Code string `json:"code"`
Message string `json:"message"`
Expand All @@ -44,6 +67,7 @@ type AliError struct {
// AliUsage holds the token counters decoded from the "usage" object of an Ali
// response.
type AliUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
// TotalTokens is the only counter embeddingResponseAli2OpenAI reads when
// building the OpenAI-style usage for embeddings.
TotalTokens int `json:"total_tokens"`
}

type AliOutput struct {
Expand Down Expand Up @@ -95,6 +119,70 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
}
}

// embeddingRequestOpenAI2Ali converts an OpenAI-style embedding request into
// the Ali request body. The model name is fixed to "text-embedding-v1" and the
// texts come from request.ParseInput(); Parameters is left nil.
func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
	aliRequest := new(AliEmbeddingRequest)
	aliRequest.Model = "text-embedding-v1"
	aliRequest.Input.Texts = request.ParseInput()
	return aliRequest
}

// aliEmbeddingHandler decodes an Ali embedding response, converts it to the
// OpenAI embedding format and writes the result to the client.
//
// It returns a non-nil error (and nil usage) when the upstream body cannot be
// decoded or closed, when the decoded payload carries a non-empty error code,
// or when the converted response cannot be marshalled. On success it returns
// nil and the usage of the converted response.
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
	var aliResponse AliEmbeddingResponse
	decodeErr := json.NewDecoder(resp.Body).Decode(&aliResponse)
	// Close the body on every path, including a failed decode, so the
	// underlying connection is not leaked. The decode error is still reported
	// first, matching the previous precedence.
	closeErr := resp.Body.Close()
	if decodeErr != nil {
		return errorWrapper(decodeErr, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if closeErr != nil {
		return errorWrapper(closeErr, "close_response_body_failed", http.StatusInternalServerError), nil
	}

	// A non-empty Code marks an upstream failure; relay it in OpenAI error
	// shape together with the upstream HTTP status.
	if aliResponse.Code != "" {
		return &OpenAIErrorWithStatusCode{
			OpenAIError: OpenAIError{
				Message: aliResponse.Message,
				Type:    aliResponse.Code,
				Param:   aliResponse.RequestId,
				Code:    aliResponse.Code,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}

	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	// The status line has already been sent, so a failed write cannot be
	// reported back to the client anymore. The write stays best-effort and the
	// usage is still returned; the error is discarded explicitly rather than
	// assigned to a variable that is never read.
	_, _ = c.Writer.Write(jsonResponse)
	return nil, &fullTextResponse.Usage
}

// embeddingResponseAli2OpenAI maps an Ali embedding response onto the OpenAI
// embedding response shape: one "embedding" item per upstream embedding, with
// the upstream text index carried over as the item index.
//
// NOTE(review): only Usage.TotalTokens is populated from the upstream usage;
// confirm that downstream consumers of Usage do not rely on other counters.
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
	embeddings := response.Output.Embeddings
	items := make([]OpenAIEmbeddingResponseItem, len(embeddings))
	for i := range embeddings {
		items[i] = OpenAIEmbeddingResponseItem{
			Object:    "embedding",
			Index:     embeddings[i].TextIndex,
			Embedding: embeddings[i].Embedding,
		}
	}
	return &OpenAIEmbeddingResponse{
		Object: "list",
		Data:   items,
		Model:  "text-embedding-v1",
		Usage:  Usage{TotalTokens: response.Usage.TotalTokens},
	}
}

func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
choice := OpenAITextResponseChoice{
Index: 0,
Expand Down
15 changes: 2 additions & 13 deletions controller/relay-baidu.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,20 +144,9 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom
}

func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
baiduEmbeddingRequest := BaiduEmbeddingRequest{
Input: nil,
return &BaiduEmbeddingRequest{
Input: request.ParseInput(),
}
switch request.Input.(type) {
case string:
baiduEmbeddingRequest.Input = []string{request.Input.(string)}
case []any:
for _, item := range request.Input.([]any) {
if str, ok := item.(string); ok {
baiduEmbeddingRequest.Input = append(baiduEmbeddingRequest.Input, str)
}
}
}
return &baiduEmbeddingRequest
}

func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
Expand Down
24 changes: 21 additions & 3 deletions controller/relay-text.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
case APITypeAli:
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
if relayMode == RelayModeEmbeddings {
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
}
case APITypeAIProxyLibrary:
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
}
Expand Down Expand Up @@ -262,8 +265,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeAli:
aliRequest := requestOpenAI2Ali(textRequest)
jsonStr, err := json.Marshal(aliRequest)
var jsonStr []byte
var err error
switch relayMode {
case RelayModeEmbeddings:
aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
jsonStr, err = json.Marshal(aliEmbeddingRequest)
default:
aliRequest := requestOpenAI2Ali(textRequest)
jsonStr, err = json.Marshal(aliRequest)
}
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
Expand Down Expand Up @@ -502,7 +513,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
return nil
} else {
err, usage := aliHandler(c, resp)
var err *OpenAIErrorWithStatusCode
var usage *Usage
switch relayMode {
case RelayModeEmbeddings:
err, usage = aliEmbeddingHandler(c, resp)
default:
err, usage = aliHandler(c, resp)
}
if err != nil {
return err
}
Expand Down
19 changes: 19 additions & 0 deletions controller/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,25 @@ type GeneralOpenAIRequest struct {
Functions any `json:"functions,omitempty"`
}

// ParseInput normalizes the Input field into a slice of strings.
//
// Supported shapes:
//   - nil:      returns nil.
//   - string:   returns a one-element slice holding that string.
//   - []any:    returns the string elements in order; non-string elements are
//     skipped. The result is non-nil even when no element qualifies.
//   - []string: returns a copy of the slice, so callers cannot alias the
//     request's own data.
//
// Any other type yields nil.
func (r GeneralOpenAIRequest) ParseInput() []string {
	// Bind the value in the type switch instead of re-asserting in each branch.
	switch v := r.Input.(type) {
	case string:
		return []string{v}
	case []any:
		input := make([]string, 0, len(v))
		for _, item := range v {
			if str, ok := item.(string); ok {
				input = append(input, str)
			}
		}
		return input
	case []string:
		input := make([]string, len(v))
		copy(input, v)
		return input
	default:
		// Covers a nil Input as well as unsupported types.
		return nil
	}
}

type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Expand Down
2 changes: 1 addition & 1 deletion web/src/pages/Channel/EditChannel.js
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ const EditChannel = () => {
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
break;
case 17:
localModels = ['qwen-v1', 'qwen-plus-v1'];
localModels = ['qwen-v1', 'qwen-plus-v1', 'text-embedding-v1'];
break;
case 16:
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
Expand Down

0 comments on commit d0a0e87

Please sign in to comment.