Skip to content

Commit

Permalink
[Bug][Transforms-V2] Fix LLM transform can not parse boolean value ty…
Browse files Browse the repository at this point in the history
…pe (apache#7620)
  • Loading branch information
hawk9821 authored Sep 10, 2024
1 parent 6c7bb04 commit 6e1f207
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@ public void testLLMWithOpenAI(TestContainer container)
Assertions.assertEquals(0, execResult.getExitCode());
}

@TestTemplate
public void testLLMWithOpenAIBoolean(TestContainer container)
throws IOException, InterruptedException {
Container.ExecResult execResult =
container.executeJob("/llm_openai_transform_boolean.conf");
Assertions.assertEquals(0, execResult.getExitCode());
}

@TestTemplate
public void testLLMWithCustomModel(TestContainer container)
throws IOException, InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
######
###### This config file is a demonstration of streaming processing in seatunnel config
######

env {
job.mode = "BATCH"
}

source {
FakeSource {
row.num = 5
schema = {
fields {
id = "int"
name = "string"
}
}
rows = [
{fields = [1, "Jia Fan"], kind = INSERT}
{fields = [2, "Hailin Wang"], kind = INSERT}
{fields = [3, "Tomas"], kind = INSERT}
{fields = [4, "Eric"], kind = INSERT}
{fields = [5, "Guangdong Liu"], kind = INSERT}
]
result_table_name = "fake"
}
}

transform {
LLM {
source_table_name = "fake"
model_provider = OPENAI
model = gpt-4o-mini
api_key = sk-xxx
prompt = "Determine whether someone is Chinese or American by their name"
output_data_type = boolean
openai.api_path = "http://mockserver:1080/v2/chat/completions"
result_table_name = "llm_output"
}
}

sink {
Assert {
source_table_name = "llm_output"
rules =
{
field_rules = [
{
field_name = llm_output
field_type = boolean
field_value = [
{
rule_type = NOT_NULL
}
]
}
]
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,41 @@
"Content-Type": "application/json"
}
}
},
{
"httpRequest": {
"method": "POST",
"path": "/v2/chat/completions"
},
"httpResponse": {
"body": {
"id": "chatcmpl-9s4hoBNGV0d9Mudkhvgzg64DAWPnx",
"object": "chat.completion",
"created": 1722674828,
"model": "gpt-4o-mini",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "[True]"
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 107,
"completion_tokens": 3,
"total_tokens": 110
},
"system_fingerprint": "fp_0f03d4f0ee",
"code": 0,
"msg": "ok"
},
"headers": {
"Content-Type": "application/json"
}
}
}
]
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,8 @@ public List<String> inference(List<SeaTunnelRow> rows) throws IOException {

protected abstract List<String> chatWithModel(String promptWithLimit, String rowsJson)
throws IOException;

protected String convertData(String data) {
return outputType == SqlType.BOOLEAN ? data.toLowerCase() : data;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ protected List<String> chatWithModel(String prompt, String data) throws IOExcept

JsonNode result = OBJECT_MAPPER.readTree(responseStr);
String resultData = result.get("choices").get(0).get("message").get("content").asText();
return OBJECT_MAPPER.readValue(resultData, new TypeReference<List<String>>() {});
return OBJECT_MAPPER.readValue(
convertData(resultData), new TypeReference<List<String>>() {});
}

@VisibleForTesting
Expand Down

0 comments on commit 6e1f207

Please sign in to comment.