Skip to content

Commit

Permalink
Fix MiniMax model function call implementation
Browse files Browse the repository at this point in the history
Implement function call capability for MiniMax model and add unit tests based on new tool classes.
This addresses most scenarios, but note the limitations in complex English contexts
with multiple function calls. For example, a weather query may stop
prematurely when querying multiple locations, due to the single-location
parameter limit. This behavior stems from model performance constraints.

Streaming function calling is not passing tests yet; it will be addressed separately.

Resolves spring-projects#1077

Implement function call capability for the Moonshot model. Include unit
tests to verify the new functionality. This feature addresses the
requirements outlined in issue spring-projects#1058.
fix: MiniMax function call

review
  • Loading branch information
mxsl-gr authored and markpollack committed Aug 23, 2024
1 parent 611c949 commit 0927bd1
Show file tree
Hide file tree
Showing 11 changed files with 279 additions and 275 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,6 @@
*/
package org.springframework.ai.minimax;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
Expand All @@ -31,6 +26,11 @@
import org.springframework.boot.context.properties.NestedConfigurationProperty;
import org.springframework.util.Assert;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
* MiniMaxChatOptions represents the options for performing chat completion using the
* MiniMax API. It provides methods to set and retrieve various options like model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
.takeUntil(SSE_DONE_PREDICATE)
.filter(SSE_DONE_PREDICATE.negate())
.map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class))
.map(chunk -> {
.map(chunk -> {
if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) {
isInsideTool.set(true);
}
Expand All @@ -730,7 +730,7 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
.concatMapIterable(window -> {
Mono<ChatCompletionChunk> monoChunk = window.reduce(
new ChatCompletionChunk(null, null, null, null, null, null),
this.chunkMerger::merge);
(previous, current) -> this.chunkMerger.merge(previous, current));
return List.of(monoChunk);
})
.flatMap(mono -> mono);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,16 @@
*/
package org.springframework.ai.minimax.api;

import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionChunk;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionChunk.ChunkChoice;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionFinishReason;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage.ChatCompletionFunction;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage.Role;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage.ToolCall;
import org.springframework.ai.minimax.api.MiniMaxApi.LogProbs;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -29,14 +38,7 @@
*/
public class MiniMaxStreamFunctionCallingHelper {

/**
* Merge the previous and current ChatCompletionChunk into a single one.
* @param previous the previous ChatCompletionChunk
* @param current the current ChatCompletionChunk
* @return the merged ChatCompletionChunk
*/
public MiniMaxApi.ChatCompletionChunk merge(MiniMaxApi.ChatCompletionChunk previous,
MiniMaxApi.ChatCompletionChunk current) {
public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChunk current) {

if (previous == null) {
return current;
Expand All @@ -49,47 +51,38 @@ public MiniMaxApi.ChatCompletionChunk merge(MiniMaxApi.ChatCompletionChunk previ
: previous.systemFingerprint());
String object = (current.object() != null ? current.object() : previous.object());

MiniMaxApi.ChatCompletionChunk.ChunkChoice previousChoice0 = (CollectionUtils.isEmpty(previous.choices()) ? null
: previous.choices().get(0));
MiniMaxApi.ChatCompletionChunk.ChunkChoice currentChoice0 = (CollectionUtils.isEmpty(current.choices()) ? null
: current.choices().get(0));
ChunkChoice previousChoice0 = (CollectionUtils.isEmpty(previous.choices()) ? null : previous.choices().get(0));
ChunkChoice currentChoice0 = (CollectionUtils.isEmpty(current.choices()) ? null : current.choices().get(0));

MiniMaxApi.ChatCompletionChunk.ChunkChoice choice = merge(previousChoice0, currentChoice0);
List<MiniMaxApi.ChatCompletionChunk.ChunkChoice> chunkChoices = choice == null ? List.of() : List.of(choice);
return new MiniMaxApi.ChatCompletionChunk(id, chunkChoices, created, model, systemFingerprint, object);
ChunkChoice choice = merge(previousChoice0, currentChoice0);
List<ChunkChoice> chunkChoices = choice == null ? List.of() : List.of(choice);
return new ChatCompletionChunk(id, chunkChoices, created, model, systemFingerprint, object);
}

private MiniMaxApi.ChatCompletionChunk.ChunkChoice merge(MiniMaxApi.ChatCompletionChunk.ChunkChoice previous,
MiniMaxApi.ChatCompletionChunk.ChunkChoice current) {
private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
if (previous == null) {
return current;
}

MiniMaxApi.ChatCompletionFinishReason finishReason = (current.finishReason() != null ? current.finishReason()
ChatCompletionFinishReason finishReason = (current.finishReason() != null ? current.finishReason()
: previous.finishReason());
Integer index = (current.index() != null ? current.index() : previous.index());
LogProbs logprobs = (current.logprobs() != null ? current.logprobs() : previous.logprobs());

MiniMaxApi.ChatCompletionMessage message = merge(previous.delta(), current.delta());

MiniMaxApi.LogProbs logprobs = (current.logprobs() != null ? current.logprobs() : previous.logprobs());
return new MiniMaxApi.ChatCompletionChunk.ChunkChoice(finishReason, index, message, logprobs);
ChatCompletionMessage message = merge(previous.delta(), current.delta());
return new ChunkChoice(finishReason, index, message, logprobs);
}

private MiniMaxApi.ChatCompletionMessage merge(MiniMaxApi.ChatCompletionMessage previous,
MiniMaxApi.ChatCompletionMessage current) {
private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) {
String content = (current.content() != null ? current.content()
: (previous.content() != null) ? previous.content() : "");
MiniMaxApi.ChatCompletionMessage.Role role = (current.role() != null ? current.role() : previous.role());
role = (role != null ? role : MiniMaxApi.ChatCompletionMessage.Role.ASSISTANT); // default
// to
// ASSISTANT
// (if
// null
: "" + ((previous.content() != null) ? previous.content() : ""));
Role role = (current.role() != null ? current.role() : previous.role());
role = (role != null ? role : Role.ASSISTANT); // default to ASSISTANT (if null
String name = (current.name() != null ? current.name() : previous.name());
String toolCallId = (current.toolCallId() != null ? current.toolCallId() : previous.toolCallId());

List<MiniMaxApi.ChatCompletionMessage.ToolCall> toolCalls = new ArrayList<>();
MiniMaxApi.ChatCompletionMessage.ToolCall lastPreviousTooCall = null;
List<ToolCall> toolCalls = new ArrayList<>();
ToolCall lastPreviousTooCall = null;
if (previous.toolCalls() != null) {
lastPreviousTooCall = previous.toolCalls().get(previous.toolCalls().size() - 1);
if (previous.toolCalls().size() > 1) {
Expand All @@ -101,58 +94,55 @@ private MiniMaxApi.ChatCompletionMessage merge(MiniMaxApi.ChatCompletionMessage
throw new IllegalStateException("Currently only one tool call is supported per message!");
}
var currentToolCall = current.toolCalls().iterator().next();
if (currentToolCall.id() != null) {
if (currentToolCall.id() == null
|| (lastPreviousTooCall != null && currentToolCall.id().equals(lastPreviousTooCall.id()))) {
toolCalls.add(merge(lastPreviousTooCall, currentToolCall));
}
else {
if (lastPreviousTooCall != null) {
toolCalls.add(lastPreviousTooCall);
}
toolCalls.add(currentToolCall);
}
else {
toolCalls.add(merge(lastPreviousTooCall, currentToolCall));
}
}
else {
if (lastPreviousTooCall != null) {
toolCalls.add(lastPreviousTooCall);
}
}
return new MiniMaxApi.ChatCompletionMessage(content, role, name, toolCallId, toolCalls);
return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls);
}

private MiniMaxApi.ChatCompletionMessage.ToolCall merge(MiniMaxApi.ChatCompletionMessage.ToolCall previous,
MiniMaxApi.ChatCompletionMessage.ToolCall current) {
private ToolCall merge(ToolCall previous, ToolCall current) {
if (previous == null) {
return current;
}
String id = (current.id() != null ? current.id() : previous.id());
String type = (current.type() != null ? current.type() : previous.type());
MiniMaxApi.ChatCompletionMessage.ChatCompletionFunction function = merge(previous.function(),
current.function());
return new MiniMaxApi.ChatCompletionMessage.ToolCall(id, type, function);
ChatCompletionFunction function = merge(previous.function(), current.function());
return new ToolCall(id, type, function);
}

private MiniMaxApi.ChatCompletionMessage.ChatCompletionFunction merge(
MiniMaxApi.ChatCompletionMessage.ChatCompletionFunction previous,
MiniMaxApi.ChatCompletionMessage.ChatCompletionFunction current) {
private ChatCompletionFunction merge(ChatCompletionFunction previous, ChatCompletionFunction current) {
if (previous == null) {
return current;
}
String name = (current.name() != null ? current.name() : previous.name());
String name = (StringUtils.hasLength(current.name()) ? current.name() : previous.name());
StringBuilder arguments = new StringBuilder();
if (previous.arguments() != null) {
arguments.append(previous.arguments());
}
if (current.arguments() != null) {
arguments.append(current.arguments());
}
return new MiniMaxApi.ChatCompletionMessage.ChatCompletionFunction(name, arguments.toString());
return new ChatCompletionFunction(name, arguments.toString());
}

/**
* @param chatCompletion the ChatCompletionChunk to check
* @return true if the ChatCompletionChunk is a streaming tool function call.
*/
public boolean isStreamingToolFunctionCall(MiniMaxApi.ChatCompletionChunk chatCompletion) {
public boolean isStreamingToolFunctionCall(ChatCompletionChunk chatCompletion) {

if (chatCompletion == null || CollectionUtils.isEmpty(chatCompletion.choices())) {
return false;
Expand All @@ -170,7 +160,7 @@ public boolean isStreamingToolFunctionCall(MiniMaxApi.ChatCompletionChunk chatCo
* @return true if the ChatCompletionChunk is a streaming tool function call and it is
* the last one.
*/
public boolean isStreamingToolFunctionCallFinish(MiniMaxApi.ChatCompletionChunk chatCompletion) {
public boolean isStreamingToolFunctionCallFinish(ChatCompletionChunk chatCompletion) {

if (chatCompletion == null || CollectionUtils.isEmpty(chatCompletion.choices())) {
return false;
Expand All @@ -180,23 +170,7 @@ public boolean isStreamingToolFunctionCallFinish(MiniMaxApi.ChatCompletionChunk
if (choice == null || choice.delta() == null) {
return false;
}
return choice.finishReason() == MiniMaxApi.ChatCompletionFinishReason.TOOL_CALLS;
}

/**
* Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null.
* @param chunk the ChatCompletionChunk to convert
* @return the ChatCompletion
*/
public MiniMaxApi.ChatCompletion chunkToChatCompletion(MiniMaxApi.ChatCompletionChunk chunk) {
List<MiniMaxApi.ChatCompletion.Choice> choices = chunk.choices()
.stream()
.map(chunkChoice -> new MiniMaxApi.ChatCompletion.Choice(chunkChoice.finishReason(), chunkChoice.index(),
chunkChoice.delta(), chunkChoice.logprobs()))
.toList();

return new MiniMaxApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(),
chunk.systemFingerprint(), "chat.completion", null, null);
return choice.finishReason() == ChatCompletionFinishReason.TOOL_CALLS;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ public class MiniMaxApiToolFunctionCallIT {
public void toolFunctionCall() {

// Step 1: send the conversation and available functions to the model
var message = new ChatCompletionMessage("What's the weather like in San Francisco?", Role.USER);
var message = new ChatCompletionMessage(
"What's the weather like in San Francisco? Return the temperature in Celsius.", Role.USER);

var functionTool = new MiniMaxApi.FunctionTool(Type.FUNCTION, new MiniMaxApi.FunctionTool.Function(
"Get the weather in location. Return temperature in 30°F or 30°C format.", "getCurrentWeather", """
Expand Down Expand Up @@ -126,7 +127,8 @@ public void toolFunctionCall() {

assertThat(chatCompletion2.getBody().choices().get(0).message().role()).isEqualTo(Role.ASSISTANT);
assertThat(chatCompletion2.getBody().choices().get(0).message().content()).contains("San Francisco")
.containsAnyOf("30.0°C", "30°C", "30.0°F", "30°F");
.containsAnyOf("30.0°C", "30°C", "30.0")
.containsAnyOf("°C", "Celsius");
}

private static <T> T fromJson(String json, Class<T> targetClass) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,8 @@ public MiniMaxChatModel miniMaxChatModel(MiniMaxConnectionProperties commonPrope
var miniMaxApi = miniMaxApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(),
chatProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, responseErrorHandler);

MiniMaxChatModel chatModel = new MiniMaxChatModel(miniMaxApi, chatProperties.getOptions(),
functionCallbackContext, retryTemplate);

if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) {
Map<String, FunctionCallback> toolFunctionCallbackMap = toolFunctionCallbacks.stream()
.collect(Collectors.toMap(FunctionCallback::getName, Function.identity(), (a, b) -> b));
chatModel.getFunctionCallbackRegister().putAll(toolFunctionCallbackMap);
}

return chatModel;
return new MiniMaxChatModel(miniMaxApi, chatProperties.getOptions(), functionCallbackContext,
toolFunctionCallbacks, retryTemplate);
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.minimax.MiniMaxChatModel;
import org.springframework.ai.minimax.MiniMaxChatOptions;
Expand Down Expand Up @@ -53,11 +53,12 @@ public class FunctionCallbackInPromptIT {

@Test
void functionCallTest() {
contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6-chat").run(context -> {
contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> {

MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class);

UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");

var promptOptions = MiniMaxChatOptions.builder()
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
Expand All @@ -78,11 +79,12 @@ void functionCallTest() {
@Test
void streamingFunctionCallTest() {

contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6-chat").run(context -> {
contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> {

MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class);

UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");

var promptOptions = MiniMaxChatOptions.builder()
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.minimax.MiniMaxChatModel;
import org.springframework.ai.minimax.MiniMaxChatOptions;
Expand Down Expand Up @@ -57,14 +57,16 @@ class FunctionCallbackWithPlainFunctionBeanIT {
RestClientAutoConfiguration.class, MiniMaxAutoConfiguration.class))
.withUserConfiguration(Config.class);

// FIXME: multiple function calls may stop prematurely due to model performance
@Test
void functionCallTest() {
contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6-chat").run(context -> {
contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> {

MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class);

// Test weatherFunction
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");

ChatResponse response = chatModel.call(new Prompt(List.of(userMessage),
MiniMaxChatOptions.builder().withFunction("weatherFunction").build()));
Expand All @@ -86,12 +88,13 @@ void functionCallTest() {

@Test
void functionCallWithPortableFunctionCallingOptions() {
contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6-chat").run(context -> {
contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> {

MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class);

// Test weatherFunction
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");

PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder()
.withFunction("weatherFunction")
Expand All @@ -103,14 +106,16 @@ void functionCallWithPortableFunctionCallingOptions() {
});
}

// FIXME: multiple function calls may stop prematurely due to model performance
@Test
void streamFunctionCallTest() {
contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6-chat").run(context -> {
contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> {

MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class);

// Test weatherFunction
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");

Flux<ChatResponse> response = chatModel.stream(new Prompt(List.of(userMessage),
MiniMaxChatOptions.builder().withFunction("weatherFunction").build()));
Expand Down
Loading

0 comments on commit 0927bd1

Please sign in to comment.