feat: add gemini pro vision model
zhu327 committed Dec 29, 2023
1 parent c60e875 commit 3baa29a
Showing 3 changed files with 188 additions and 8 deletions.
19 changes: 16 additions & 3 deletions api/handler.go
@@ -60,7 +60,7 @@ func ChatProxyHandler(c *gin.Context) {
return
}

-	var req *adapter.ChatCompletionRequest
+	var req = &adapter.ChatCompletionRequest{}
// Bind the JSON data from the request to the struct
if err := c.ShouldBindJSON(req); err != nil {
c.JSON(http.StatusBadRequest, openai.APIError{
@@ -82,7 +82,12 @@ func ChatProxyHandler(c *gin.Context) {
}
defer client.Close()

-	gemini := adapter.NewGeminiProAdapter(client)
+	var gemini adapter.GenaiModelAdapter
+	if req.Model == openai.GPT4VisionPreview {
+		gemini = adapter.NewGeminiProVisionAdapter(client)
+	} else {
+		gemini = adapter.NewGeminiProAdapter(client)
+	}

if !req.Stream {
resp, err := gemini.GenerateContent(ctx, req)
@@ -99,7 +104,15 @@ func ChatProxyHandler(c *gin.Context) {
return
}

-	dataChan := gemini.GenerateStreamContent(ctx, req)
+	dataChan, err := gemini.GenerateStreamContent(ctx, req)
+	if err != nil {
+		log.Printf("genai generate content error %v\n", err)
+		c.JSON(http.StatusBadRequest, openai.APIError{
+			Code:    http.StatusBadRequest,
+			Message: err.Error(),
+		})
+		return
+	}

setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
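For reference, a request only takes the new gemini-pro-vision path when its model is openai.GPT4VisionPreview ("gpt-4-vision-preview"). Below is a minimal client-side sketch using the go-openai library; the proxy address, the /v1 base path, and passing the Gemini API key as the bearer token are assumptions about a typical local deployment, not something shown in this diff.

package main

import (
	"context"
	"fmt"
	"log"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	// Assumption: the proxy listens on localhost:8080 and accepts the Gemini API key
	// in place of an OpenAI key.
	config := openai.DefaultConfig("YOUR_GEMINI_API_KEY")
	config.BaseURL = "http://localhost:8080/v1"
	client := openai.NewClientWithConfig(config)

	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		// gpt-4-vision-preview is routed to the gemini-pro-vision adapter by ChatProxyHandler.
		Model: openai.GPT4VisionPreview,
		Messages: []openai.ChatCompletionMessage{
			{
				Role: openai.ChatMessageRoleUser,
				MultiContent: []openai.ChatMessagePart{
					{Type: openai.ChatMessagePartTypeText, Text: "What is in this picture?"},
					{
						Type:     openai.ChatMessagePartTypeImageURL,
						ImageURL: &openai.ChatMessageImageURL{URL: "https://example.com/cat.png"},
					},
				},
			},
		},
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.Choices[0].Message.Content)
}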
88 changes: 83 additions & 5 deletions pkg/adapter/chat.go
@@ -17,15 +17,16 @@ import (
)

const (
-	GeminiPro = "gemini-pro"
+	GeminiPro       = "gemini-pro"
+	GeminiProVision = "gemini-pro-vision"

genaiRoleUser = "user"
genaiRoleModel = "model"
)

type GenaiModelAdapter interface {
GenerateContent(ctx context.Context, req *ChatCompletionRequest) (*openai.ChatCompletionResponse, error)
-	GenerateStreamContent(ctx context.Context, req *ChatCompletionRequest) <-chan string
+	GenerateStreamContent(ctx context.Context, req *ChatCompletionRequest) (<-chan string, error)
}

type GeminiProAdapter struct {
@@ -51,7 +52,6 @@ func (g *GeminiProAdapter) GenerateContent(
prompt := genai.Text(req.Messages[len(req.Messages)-1].Content)
genaiResp, err := cs.SendMessage(ctx, prompt)
if err != nil {
-		log.Printf("genai send message error %v\n", err)
return nil, errors.Wrap(err, "genai send message error")
}

@@ -62,7 +62,7 @@
func (g *GeminiProAdapter) GenerateStreamContent(
ctx context.Context,
req *ChatCompletionRequest,
-) <-chan string {
+) (<-chan string, error) {
model := g.client.GenerativeModel(GeminiPro)
setGenaiModelByOpenaiRequest(model, req)

@@ -75,7 +75,85 @@ func (g *GeminiProAdapter) GenerateStreamContent(
dataChan := make(chan string)
go handleStreamIter(iter, dataChan)

-	return dataChan
+	return dataChan, nil
}

type GeminiProVisionAdapter struct {
client *genai.Client
}

func NewGeminiProVisionAdapter(client *genai.Client) GenaiModelAdapter {
return &GeminiProVisionAdapter{
client: client,
}
}

func (g *GeminiProVisionAdapter) GenerateContent(
ctx context.Context,
req *ChatCompletionRequest,
) (*openai.ChatCompletionResponse, error) {
model := g.client.GenerativeModel(GeminiProVision)
setGenaiModelByOpenaiRequest(model, req)

// NOTE: use last message as prompt, gemini pro vision does not support context
// https://ai.google.dev/tutorials/go_quickstart#multi-turn-conversations-chat
prompt, err := g.openaiMessageToGenaiPrompt(req.Messages[len(req.Messages)-1])
if err != nil {
return nil, errors.Wrap(err, "genai generate prompt error")
}

genaiResp, err := model.GenerateContent(ctx, prompt...)
if err != nil {
return nil, errors.Wrap(err, "genai send message error")
}

openaiResp := genaiResponseToOpenaiResponse(genaiResp)
return &openaiResp, nil
}

func (*GeminiProVisionAdapter) openaiMessageToGenaiPrompt(msg ChatCompletionMessage) ([]genai.Part, error) {
parts, err := msg.MultiContent()
if err != nil {
return nil, err
}

prompt := make([]genai.Part, 0, len(parts))
for _, part := range parts {
switch part.Type {
case openai.ChatMessagePartTypeText:
prompt = append(prompt, genai.Text(part.Text))
case openai.ChatMessagePartTypeImageURL:
data, format, err := parseImageURL(part.ImageURL.URL)
if err != nil {
return nil, errors.Wrap(err, "parse image url error")
}

prompt = append(prompt, genai.ImageData(format, data))
}
}
return prompt, nil
}

func (g *GeminiProVisionAdapter) GenerateStreamContent(
ctx context.Context,
req *ChatCompletionRequest,
) (<-chan string, error) {
model := g.client.GenerativeModel(GeminiProVision)
setGenaiModelByOpenaiRequest(model, req)

// NOTE: use last message as prompt, gemini pro vision does not support context
// https://ai.google.dev/tutorials/go_quickstart#multi-turn-conversations-chat
prompt, err := g.openaiMessageToGenaiPrompt(req.Messages[len(req.Messages)-1])
if err != nil {
return nil, errors.Wrap(err, "genai generate prompt error")
}

iter := model.GenerateContentStream(ctx, prompt...)

dataChan := make(chan string)
go handleStreamIter(iter, dataChan)

return dataChan, nil
}

func handleStreamIter(iter *genai.GenerateContentResponseIterator, dataChan chan string) {
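On the genai side, the conversion in openaiMessageToGenaiPrompt boils down to text parts becoming genai.Text and image parts becoming genai.ImageData. A minimal sketch of the resulting prompt, assumed to live in the adapter package; the helper name and the jpegBytes argument are hypothetical, used only for illustration.

package adapter

import "github.com/google/generative-ai-go/genai"

// buildVisionPromptExample mirrors what openaiMessageToGenaiPrompt produces for a
// user message with one text part and one JPEG image part. Only the final user
// message is converted, since gemini-pro-vision does not support multi-turn chat.
func buildVisionPromptExample(jpegBytes []byte) []genai.Part {
	return []genai.Part{
		genai.Text("Describe this image."),
		genai.ImageData("jpeg", jpegBytes), // format is the subtype only, without the "image/" prefix
	}
}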
89 changes: 89 additions & 0 deletions pkg/adapter/image.go
@@ -0,0 +1,89 @@
package adapter

import (
"encoding/base64"
"fmt"
"io"
"net/http"
"strings"
)

func parseImageURL(imageURL string) ([]byte, string, error) {
if strings.HasPrefix(imageURL, "data:image/") {
return decodeBase64Image(imageURL)
}
return getImageInfoFromURL(imageURL)
}

func decodeBase64Image(base64String string) ([]byte, string, error) {
	// Extract the image format (e.g. "png") from the data URI before stripping it,
	// since getImageFormat needs the "data:image/<format>;" prefix to still be present
	format, err := getImageFormat(base64String)
	if err != nil {
		return nil, "", err
	}

	// Remove the data URI prefix (e.g. "data:image/png;base64,") if present
	index := strings.Index(base64String, ";base64,")
	if index != -1 {
		base64String = base64String[index+len(";base64,"):]
	}

	// Decode the base64 payload into a byte slice
	data, err := base64.StdEncoding.DecodeString(base64String)
	if err != nil {
		return nil, "", err
	}

	return data, format, nil
}

func getImageFormat(dataURI string) (string, error) {
// Find the index of "image/"
startIndex := strings.Index(dataURI, "image/")
if startIndex == -1 {
return "", fmt.Errorf("image format not found in data URI")
}

// Extract the substring between "image/" and ";"
startIndex += len("image/")
endIndex := strings.Index(dataURI[startIndex:], ";")
if endIndex == -1 {
return "", fmt.Errorf("image format not found in data URI")
}

return dataURI[startIndex : startIndex+endIndex], nil
}

func getImageInfoFromURL(url string) ([]byte, string, error) {
// Make an HTTP GET request to the URL
response, err := http.Get(url)
if err != nil {
return nil, "", err
}
defer response.Body.Close()

// Read the response body
imageData, err := io.ReadAll(response.Body)
if err != nil {
return nil, "", err
}

// Extract image format from the "Content-Type" header
contentType := response.Header.Get("Content-Type")
format, err := getImageFormatFromContentType(contentType)
if err != nil {
return nil, "", err
}

return imageData, format, nil
}

func getImageFormatFromContentType(contentType string) (string, error) {
// Extract image format from the "Content-Type" header
parts := strings.Split(contentType, "/")
if len(parts) != 2 {
return "", fmt.Errorf("invalid Content-Type header")
}
return parts[1], nil
}
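
parseImageURL accepts either a data URI or a plain HTTP(S) URL and returns the raw image bytes plus the format subtype that genai.ImageData expects. A small in-package sketch of the data URI path; the test name and payload are assumptions for illustration, not part of this commit.

package adapter

import (
	"encoding/base64"
	"testing"
)

// TestParseImageURLDataURI checks that a data URI is split into its decoded
// payload and its format subtype ("png" here).
func TestParseImageURLDataURI(t *testing.T) {
	payload := []byte("not-a-real-image") // parseImageURL does not validate image contents
	dataURI := "data:image/png;base64," + base64.StdEncoding.EncodeToString(payload)

	data, format, err := parseImageURL(dataURI)
	if err != nil {
		t.Fatalf("parseImageURL returned error: %v", err)
	}
	if format != "png" {
		t.Fatalf("expected format %q, got %q", "png", format)
	}
	if string(data) != string(payload) {
		t.Fatalf("decoded payload does not match the original")
	}
}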
