[Feature][Transform] Add embedding transform (apache#7534)

corgy-w authored Sep 3, 2024
1 parent 7d31e56 commit 3310cfc
Showing 36 changed files with 3,163 additions and 86 deletions.
392 changes: 392 additions & 0 deletions docs/en/transform-v2/embedding.md

Large diffs are not rendered by default.

28 changes: 15 additions & 13 deletions docs/en/transform-v2/llm.md
@@ -10,14 +10,14 @@ more.

## Options

- | name             | type   | required | default value                              |
- |------------------|--------|----------|--------------------------------------------|
- | model_provider   | enum   | yes      |                                            |
- | output_data_type | enum   | no       | String                                     |
- | prompt           | string | yes      |                                            |
- | model            | string | yes      |                                            |
- | api_key          | string | yes      |                                            |
- | openai.api_path  | string | no       | https://api.openai.com/v1/chat/completions |
+ | name             | type   | required | default value |
+ |------------------|--------|----------|---------------|
+ | model_provider   | enum   | yes      |               |
+ | output_data_type | enum   | no       | String        |
+ | prompt           | string | yes      |               |
+ | model            | string | yes      |               |
+ | api_key          | string | yes      |               |
+ | api_path         | string | no       |               |

### model_provider

@@ -36,7 +36,7 @@ The prompt to send to the LLM. This parameter defines how LLM will process and r

The data read from source is a table like this:

| name          | age |
|---------------|-----|
| Jia Fan       | 20  |
| Hailin Wang   | 20  |
@@ -51,7 +51,7 @@ Determine whether someone is Chinese or American by their name

The result will be:

| name          | age | llm_output |
|---------------|-----|------------|
| Jia Fan       | 20  | Chinese    |
| Hailin Wang   | 20  | Chinese    |
@@ -61,16 +61,18 @@ The result will be:
### model

The model to use. Different model providers have different models. For example, the OpenAI model can be `gpt-4o-mini`.
If you use an OpenAI model, please refer to https://platform.openai.com/docs/models/model-endpoint-compatibility for the `/v1/chat/completions` endpoint.

### api_key

The API key to use for the model provider.
If you use an OpenAI model, please refer to https://platform.openai.com/docs/api-reference/api-keys for how to get the API key.

- ### openai.api_path
+ ### api_path

- The API path to use for the OpenAI model provider. In most cases, you do not need to change this configuration. If you are using an API agent's service, you may need to configure it to the agent's API address.
+ The API path to use for the model provider. In most cases, you do not need to change this configuration. If you are using an API agent's service, you may need to configure it to the agent's API address.
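For reference, the renamed option slots into a job config along these lines. This is a sketch, not an example from this commit: the key is a placeholder, the proxy URL is invented, and `OPENAI` as the `model_provider` value is an assumption.

```hocon
transform {
  LLM {
    model_provider = OPENAI   # assumed enum value
    model = "gpt-4o-mini"
    api_key = "sk-your-key"   # placeholder
    prompt = "Determine whether someone is Chinese or American by their name"
    # Optional: only needed when routing requests through a proxy/agent service
    api_path = "http://my-proxy:1080/v1/chat/completions"
  }
}
```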

### common options [string]

382 changes: 382 additions & 0 deletions docs/zh/transform-v2/embedding.md

Large diffs are not rendered by default.

24 changes: 12 additions & 12 deletions docs/zh/transform-v2/llm.md
@@ -8,14 +8,14 @@

## Options

- | name             | type   | required | default value                              |
- |------------------|--------|----------|--------------------------------------------|
- | model_provider   | enum   | yes      |                                            |
- | output_data_type | enum   | no       | String                                     |
- | prompt           | string | yes      |                                            |
- | model            | string | yes      |                                            |
- | api_key          | string | yes      |                                            |
- | openai.api_path  | string | no       | https://api.openai.com/v1/chat/completions |
+ | name             | type   | required | default value |
+ |------------------|--------|----------|---------------|
+ | model_provider   | enum   | yes      |               |
+ | output_data_type | enum   | no       | String        |
+ | prompt           | string | yes      |               |
+ | model            | string | yes      |               |
+ | api_key          | string | yes      |               |
+ | api_path         | string | no       |               |

### model_provider

@@ -34,7 +34,7 @@ STRING,INT,BIGINT,DOUBLE,BOOLEAN.

The data read from source is a table like this:

| name          | age |
|---------------|-----|
| Jia Fan       | 20  |
| Hailin Wang   | 20  |
@@ -49,7 +49,7 @@ Determine whether someone is Chinese or American by their name

The result will be:

| name          | age | llm_output |
|---------------|-----|------------|
| Jia Fan       | 20  | Chinese    |
| Hailin Wang   | 20  | Chinese    |
@@ -66,9 +66,9 @@ Determine whether someone is Chinese or American by their name
The API key to use for the model provider.
If you use an OpenAI model, please refer to https://platform.openai.com/docs/api-reference/api-keys for how to get the API key.

- ### openai.api_path
+ ### api_path

- The API path to use for the OpenAI model provider. In most cases, you do not need to change this configuration. If you are using an API agent's service, you may need to configure it to the agent's API address.
+ The API path to use for the model provider. In most cases, you do not need to change this configuration. If you are using an API agent's service, you may need to configure it to the agent's API address.

### common options [string]

1 change: 1 addition & 0 deletions plugin-mapping.properties
@@ -147,4 +147,5 @@ seatunnel.transform.Split = seatunnel-transforms-v2
seatunnel.transform.Copy = seatunnel-transforms-v2
seatunnel.transform.DynamicCompile = seatunnel-transforms-v2
seatunnel.transform.LLM = seatunnel-transforms-v2
seatunnel.transform.Embedding = seatunnel-transforms-v2

Expand Up @@ -295,6 +295,9 @@ private int getBytesForValue(Object v) {
                    size += getBytesForValue(entry.getKey()) + getBytesForValue(entry.getValue());
                }
                return size;
            case "HeapByteBuffer":
            case "ByteBuffer":
                return ((ByteBuffer) v).capacity();
            case "SeaTunnelRow":
                int rowSize = 0;
                SeaTunnelRow row = (SeaTunnelRow) v;
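The `ByteBuffer` branch added above sizes a vector payload by its buffer capacity instead of walking elements. A self-contained sketch of that idea — `estimateBytes` is a hypothetical stand-in, not the real `getBytesForValue`:

```java
import java.nio.ByteBuffer;

public class ByteBufferSizing {

    // Hypothetical stand-in for the branch above: dispatch on the value's
    // simple class name and size ByteBuffer-backed vectors by capacity(),
    // which counts the whole backing buffer regardless of position/limit.
    static int estimateBytes(Object v) {
        switch (v.getClass().getSimpleName()) {
            case "HeapByteBuffer":
            case "ByteBuffer":
                return ((ByteBuffer) v).capacity();
            case "String":
                return ((String) v).length();
            default:
                return 8; // crude fixed-size fallback for the sketch
        }
    }

    public static void main(String[] args) {
        ByteBuffer vec = ByteBuffer.allocate(16); // e.g. a 4-float embedding
        vec.putFloat(1.0f); // position advances, capacity stays 16
        System.out.println(estimateBytes(vec)); // 16
        System.out.println(estimateBytes("abc")); // 3
    }
}
```

Note that `ByteBuffer.allocate` returns a `HeapByteBuffer`, so the first case label matches; direct buffers would need their own case.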
@@ -36,6 +36,7 @@

import java.io.IOException;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
@@ -307,6 +308,13 @@ private Boolean checkType(Object value, SeaTunnelDataType<?> fieldType) {
            return checkDecimalType(value, fieldType);
        }

        if (fieldType.getSqlType() == SqlType.FLOAT_VECTOR
                || fieldType.getSqlType() == SqlType.FLOAT16_VECTOR
                || fieldType.getSqlType() == SqlType.BFLOAT16_VECTOR
                || fieldType.getSqlType() == SqlType.BINARY_VECTOR) {
            return value instanceof ByteBuffer;
        }

        return value.getClass().equals(fieldType.getTypeClass());
    }

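The vector branch above relaxes the usual class-equality check: any of the four vector SQL types is considered valid when the value arrives as a `ByteBuffer`. A minimal standalone sketch — the enum here is a stand-in for the real `SqlType`, and the non-vector fallback is simplified:

```java
import java.nio.ByteBuffer;

public class VectorTypeCheck {

    // Stand-in for the relevant SqlType constants
    enum SqlType { FLOAT_VECTOR, FLOAT16_VECTOR, BFLOAT16_VECTOR, BINARY_VECTOR, STRING }

    // Vector fields carry their elements as a raw ByteBuffer payload,
    // so the check is an instanceof test rather than class equality.
    static boolean checkType(Object value, SqlType type) {
        switch (type) {
            case FLOAT_VECTOR:
            case FLOAT16_VECTOR:
            case BFLOAT16_VECTOR:
            case BINARY_VECTOR:
                return value instanceof ByteBuffer;
            default:
                return value instanceof String; // simplified fallback for the sketch
        }
    }

    public static void main(String[] args) {
        System.out.println(checkType(ByteBuffer.allocate(8), SqlType.FLOAT_VECTOR)); // true
        System.out.println(checkType("not a vector", SqlType.FLOAT_VECTOR)); // false
    }
}
```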
@@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.seatunnel.e2e.transform;

import org.apache.seatunnel.e2e.common.TestResource;
import org.apache.seatunnel.e2e.common.container.EngineType;
import org.apache.seatunnel.e2e.common.container.TestContainer;
import org.apache.seatunnel.e2e.common.junit.DisabledOnContainer;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.TestTemplate;
import org.testcontainers.containers.Container;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.output.Slf4jLogConsumer;
import org.testcontainers.containers.wait.strategy.HttpWaitStrategy;
import org.testcontainers.lifecycle.Startables;
import org.testcontainers.utility.DockerImageName;
import org.testcontainers.utility.DockerLoggerFactory;
import org.testcontainers.utility.MountableFile;

import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.util.Optional;
import java.util.stream.Stream;

@DisabledOnContainer(
        value = {},
        type = {EngineType.SPARK, EngineType.FLINK},
        disabledReason = "Currently SPARK and FLINK not support adapt")
public class TestEmbeddingIT extends TestSuiteBase implements TestResource {
    private static final String TMP_DIR = "/tmp";
    private GenericContainer<?> mockserverContainer;
    private static final String IMAGE = "mockserver/mockserver:5.14.0";

    @BeforeAll
    @Override
    public void startUp() {
        Optional<URL> resource =
                Optional.ofNullable(TestEmbeddingIT.class.getResource("/mock-embedding.json"));
        this.mockserverContainer =
                new GenericContainer<>(DockerImageName.parse(IMAGE))
                        .withNetwork(NETWORK)
                        .withNetworkAliases("mockserver")
                        .withExposedPorts(1080)
                        .withCopyFileToContainer(
                                MountableFile.forHostPath(
                                        new File(
                                                        resource.orElseThrow(
                                                                        () ->
                                                                                new IllegalArgumentException(
                                                                                        "Can not get config file of mockServer"))
                                                                .getPath())
                                                .getAbsolutePath()),
                                TMP_DIR + "/mock-embedding.json")
                        .withEnv(
                                "MOCKSERVER_INITIALIZATION_JSON_PATH",
                                TMP_DIR + "/mock-embedding.json")
                        .withEnv("MOCKSERVER_LOG_LEVEL", "WARN")
                        .withLogConsumer(new Slf4jLogConsumer(DockerLoggerFactory.getLogger(IMAGE)))
                        .waitingFor(new HttpWaitStrategy().forPath("/").forStatusCode(404));
        Startables.deepStart(Stream.of(mockserverContainer)).join();
    }

    @AfterAll
    @Override
    public void tearDown() throws Exception {
        if (mockserverContainer != null) {
            mockserverContainer.stop();
        }
    }

    @TestTemplate
    public void testEmbedding(TestContainer container) throws IOException, InterruptedException {
        Container.ExecResult execResult = container.executeJob("/embedding_transform.conf");
        Assertions.assertEquals(0, execResult.getExitCode());
    }

    @TestTemplate
    public void testEmbeddingWithCustomModel(TestContainer container)
            throws IOException, InterruptedException {
        Container.ExecResult execResult = container.executeJob("/embedding_transform_custom.conf");
        Assertions.assertEquals(0, execResult.getExitCode());
    }
}
@@ -87,4 +87,11 @@ public void testLLMWithOpenAI(TestContainer container)
        Container.ExecResult execResult = container.executeJob("/llm_openai_transform.conf");
        Assertions.assertEquals(0, execResult.getExitCode());
    }

    @TestTemplate
    public void testLLMWithCustomModel(TestContainer container)
            throws IOException, InterruptedException {
        Container.ExecResult execResult = container.executeJob("/llm_transform_custom.conf");
        Assertions.assertEquals(0, execResult.getExitCode());
    }
}