Skip to content

Commit

Permalink
api: enable tool streaming (ollama#7836)
Browse files Browse the repository at this point in the history
  • Loading branch information
ParthSareen authored Nov 27, 2024
1 parent e3936d4 commit ce7455a
Show file tree
Hide file tree
Showing 4 changed files with 289 additions and 13 deletions.
13 changes: 9 additions & 4 deletions openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,9 @@ func toolCallId() string {
return "call_" + strings.ToLower(string(b))
}

func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
for i, tc := range r.Message.ToolCalls {
func toToolCalls(tc []api.ToolCall) []ToolCall {
toolCalls := make([]ToolCall, len(tc))
for i, tc := range tc {
toolCalls[i].ID = toolCallId()
toolCalls[i].Type = "function"
toolCalls[i].Function.Name = tc.Function.Name
Expand All @@ -215,7 +215,11 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {

toolCalls[i].Function.Arguments = string(args)
}
return toolCalls
}

func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
toolCalls := toToolCalls(r.Message.ToolCalls)
return ChatCompletion{
Id: id,
Object: "chat.completion",
Expand Down Expand Up @@ -244,6 +248,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
}

func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
toolCalls := toToolCalls(r.Message.ToolCalls)
return ChatCompletionChunk{
Id: id,
Object: "chat.completion.chunk",
Expand All @@ -252,7 +257,7 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
SystemFingerprint: "fp_ollama",
Choices: []ChunkChoice{{
Index: 0,
Delta: Message{Role: "assistant", Content: r.Message.Content},
Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
FinishReason: func(reason string) *string {
if len(reason) > 0 {
return &reason
Expand Down
1 change: 1 addition & 0 deletions server/model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ func TestExecuteWithTools(t *testing.T) {
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false},
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
Expand Down
32 changes: 31 additions & 1 deletion server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -1458,6 +1458,7 @@ func (s *Server) ChatHandler(c *gin.Context) {

prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
if err != nil {
slog.Error("chat prompt error", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
Expand All @@ -1467,6 +1468,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
ch := make(chan any)
go func() {
defer close(ch)
var sb strings.Builder
var hasToolCalls bool
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Expand All @@ -1492,7 +1495,34 @@ func (s *Server) ChatHandler(c *gin.Context) {
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}

ch <- res
// TODO: tool call checking and filtering should be moved outside of this callback once streaming
// however this was a simple change for now without reworking streaming logic of this (and other)
// handlers
if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 {
ch <- res
return
}

// Streaming tool calls:
// If tools are recognized, use a flag to track the sending of a tool downstream
// This ensures that content is cleared from the message on the last chunk sent
sb.WriteString(r.Content)
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
res.Message.ToolCalls = toolCalls
res.Message.Content = ""
sb.Reset()
hasToolCalls = true
ch <- res
return
}

if r.Done {
// Send any remaining content if no tool calls were detected
if !hasToolCalls {
res.Message.Content = sb.String()
}
ch <- res
}
}); err != nil {
ch <- gin.H{"error": err.Error()}
}
Expand Down
Loading

0 comments on commit ce7455a

Please sign in to comment.