[FLINK-22653][python][table-planner-blink] Support StreamExecPythonOverAggregate json serialization/deserialization

This closes apache#15937.
HuangXingBo committed May 18, 2021
1 parent 06dec01 commit 4d33e85
Showing 9 changed files with 2,558 additions and 3 deletions.
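
In effect, the change lets a streaming over-aggregate that calls Pandas UDAFs survive the planner's JSON plan round trip: the INSERT statement is compiled into a JSON plan once, and that plan, rather than the original SQL, is executed later. A rough Java sketch of that flow, using the internal, experimental getJsonPlan/executeJsonPlan hooks that the new Python test below drives through `_j_tenv`; the `source_table`/`sink_table` DDL and the UDAF registrations are assumed to have already been issued against `tEnv`, as in that test:

```java
import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.api.internal.TableEnvironmentInternal;

public class OverAggregateJsonPlanRoundTrip {
    public static void main(String[] args) throws Exception {
        TableEnvironment tEnv =
                TableEnvironment.create(EnvironmentSettings.newInstance().inStreamingMode().build());

        // Assumed: the source_table / sink_table DDL and the Pandas UDAF
        // registrations from the test below have already been executed on tEnv.
        TableEnvironmentInternal internal = (TableEnvironmentInternal) tEnv;

        // Compile the INSERT statement into a JSON plan that embeds the
        // serialized StreamExecPythonOverAggregate node ...
        String jsonPlan =
                internal.getJsonPlan(
                        "INSERT INTO sink_table "
                                + "SELECT a, mean_udaf(b) OVER ("
                                + "PARTITION BY a ORDER BY rowtime "
                                + "ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) "
                                + "FROM source_table");

        // ... then execute the restored plan instead of the original SQL.
        internal.executeJsonPlan(jsonPlan).await();
    }
}
```
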
70 changes: 70 additions & 0 deletions flink-python/pyflink/table/tests/test_pandas_udaf.py
@@ -760,6 +760,76 @@ def test_proc_time_over_rows_window_aggregate_function(self):
                            "+I[3, 2.0, 4]"])
        os.remove(source_path)

    def test_execute_over_aggregate_from_json_plan(self):
        # create source file path
        tmp_dir = self.tempdir
        data = [
            '1,1,2013-01-01 03:10:00',
            '3,2,2013-01-01 03:10:00',
            '2,1,2013-01-01 03:10:00',
            '1,5,2013-01-01 03:10:00',
            '1,8,2013-01-01 04:20:00',
            '2,3,2013-01-01 03:30:00'
        ]
        source_path = tmp_dir + '/test_execute_over_aggregate_from_json_plan.csv'
        sink_path = tmp_dir + '/test_execute_over_aggregate_from_json_plan'
        with open(source_path, 'w') as fd:
            for ele in data:
                fd.write(ele + '\n')

        source_table = """
            CREATE TABLE source_table (
                a TINYINT,
                b SMALLINT,
                rowtime TIMESTAMP(3),
                WATERMARK FOR rowtime AS rowtime - INTERVAL '60' MINUTE
            ) WITH (
                'connector' = 'filesystem',
                'path' = '%s',
                'format' = 'csv'
            )
        """ % source_path
        self.t_env.execute_sql(source_table)

        self.t_env.execute_sql("""
            CREATE TABLE sink_table (
                a TINYINT,
                b FLOAT,
                c SMALLINT
            ) WITH (
                'connector' = 'filesystem',
                'path' = '%s',
                'format' = 'csv'
            )
        """ % sink_path)

        max_add_min_udaf = udaf(lambda a: a.max() + a.min(),
                                result_type=DataTypes.SMALLINT(),
                                func_type='pandas')
        self.t_env.get_config().get_configuration().set_string(
            "pipeline.time-characteristic", "EventTime")
        self.t_env.create_temporary_system_function("mean_udaf", mean_udaf)
        self.t_env.create_temporary_system_function("max_add_min_udaf", max_add_min_udaf)

        json_plan = self.t_env._j_tenv.getJsonPlan("""
            insert into sink_table
            select a,
                   mean_udaf(b)
                   over (PARTITION BY a ORDER BY rowtime
                       ROWS BETWEEN 1 PRECEDING AND CURRENT ROW),
                   max_add_min_udaf(b)
                   over (PARTITION BY a ORDER BY rowtime
                       ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)
            from source_table
        """)
        from py4j.java_gateway import get_method
        get_method(self.t_env._j_tenv.executeJsonPlan(json_plan), "await")()

        import glob
        lines = [line.strip() for file in glob.glob(sink_path + '/*') for line in open(file, 'r')]
        lines.sort()
        self.assertEqual(lines, ['1,1.0,2', '1,3.0,6', '1,6.5,13', '2,1.0,2', '2,2.0,4', '3,2.0,4'])


@udaf(result_type=DataTypes.FLOAT(), func_type="pandas")
def mean_udaf(v):
@@ -46,6 +46,9 @@
import org.apache.flink.table.types.logical.TimestampKind;
import org.apache.flink.table.types.logical.TimestampType;

import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;

import org.apache.calcite.rel.core.AggregateCall;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -54,6 +57,10 @@
import java.lang.reflect.InvocationTargetException;
import java.math.BigDecimal;
import java.util.Collections;
import java.util.List;

import static org.apache.flink.util.Preconditions.checkArgument;
import static org.apache.flink.util.Preconditions.checkNotNull;

/** Stream {@link ExecNode} for python time-based over operator. */
public class StreamExecPythonOverAggregate extends ExecNodeBase<RowData>
@@ -77,15 +84,34 @@ public class StreamExecPythonOverAggregate extends ExecNodeBase<RowData>
            "org.apache.flink.table.runtime.operators.python.aggregate.arrow.stream."
                    + "StreamArrowPythonProcTimeBoundedRowsOperator";

    public static final String FIELD_NAME_OVER_SPEC = "overSpec";

    @JsonProperty(FIELD_NAME_OVER_SPEC)
    private final OverSpec overSpec;

    public StreamExecPythonOverAggregate(
            OverSpec overSpec,
            InputProperty inputProperty,
            RowType outputType,
            String description) {
        super(Collections.singletonList(inputProperty), outputType, description);
        this.overSpec = overSpec;
        this(
                overSpec,
                getNewNodeId(),
                Collections.singletonList(inputProperty),
                outputType,
                description);
    }

    @JsonCreator
    public StreamExecPythonOverAggregate(
            @JsonProperty(FIELD_NAME_OVER_SPEC) OverSpec overSpec,
            @JsonProperty(FIELD_NAME_ID) int id,
            @JsonProperty(FIELD_NAME_INPUT_PROPERTIES) List<InputProperty> inputProperties,
            @JsonProperty(FIELD_NAME_OUTPUT_TYPE) RowType outputType,
            @JsonProperty(FIELD_NAME_DESCRIPTION) String description) {
        super(id, inputProperties, outputType, description);
        checkArgument(inputProperties.size() == 1);
        this.overSpec = checkNotNull(overSpec);
    }

    @SuppressWarnings("unchecked")
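The @JsonCreator constructor and @JsonProperty annotations added above are the whole serde mechanism: the planner writes the node's fields into the JSON plan and later rebuilds the immutable exec node by feeding those fields back through the annotated constructor. A minimal sketch of that Jackson pattern, using a hypothetical stand-in class and plain (unshaded) Jackson rather than Flink's flink-shaded-jackson, purely to illustrate the round trip:

```java
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;

/** Hypothetical, stripped-down stand-in for a JSON-serializable exec node. */
public class OverAggregateNodeSketch {

    public static final String FIELD_NAME_OVER_SPEC = "overSpec";

    // Stands in for the real OverSpec POJO; only the Jackson wiring matters here.
    @JsonProperty(FIELD_NAME_OVER_SPEC)
    private final String overSpec;

    @JsonCreator
    public OverAggregateNodeSketch(@JsonProperty(FIELD_NAME_OVER_SPEC) String overSpec) {
        this.overSpec = overSpec;
    }

    public String getOverSpec() {
        return overSpec;
    }

    public static void main(String[] args) throws Exception {
        ObjectMapper mapper = new ObjectMapper();

        // Serialize the immutable node to JSON ...
        String json =
                mapper.writeValueAsString(
                        new OverAggregateNodeSketch("ROWS BETWEEN 1 PRECEDING AND CURRENT ROW"));

        // ... and rebuild it through the @JsonCreator constructor, the same pattern
        // the planner uses when restoring an exec node from a persisted JSON plan.
        OverAggregateNodeSketch restored = mapper.readValue(json, OverAggregateNodeSketch.class);
        System.out.println(json + " -> " + restored.getOverSpec());
    }
}
```

The real StreamExecPythonOverAggregate works the same way, with the OverSpec, node id, input properties, output type, and description as the serialized fields.
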
@@ -45,7 +45,6 @@ public class JsonSerdeCoverageTest {
                    "StreamExecWindowTableFunction",
                    "StreamExecGroupTableAggregate",
                    "StreamExecPythonGroupTableAggregate",
                    "StreamExecPythonOverAggregate",
                    "StreamExecSort",
                    "StreamExecMultipleInput",
                    "StreamExecValues");
@@ -0,0 +1,147 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.plan.nodes.exec.stream;

import org.apache.flink.table.api.TableConfig;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedAggFunctions.PandasAggregateFunction;
import org.apache.flink.table.planner.utils.StreamTableTestUtil;
import org.apache.flink.table.planner.utils.TableTestBase;

import org.junit.Before;
import org.junit.Test;

/** Test json serialization for over aggregate. */
public class PythonOverAggregateJsonPlanTest extends TableTestBase {
    private StreamTableTestUtil util;
    private TableEnvironment tEnv;

    @Before
    public void setup() {
        util = streamTestUtil(TableConfig.getDefault());
        tEnv = util.getTableEnv();
        String srcTableDdl =
                "CREATE TABLE MyTable (\n"
                        + " a int,\n"
                        + " b varchar,\n"
                        + " c int not null,\n"
                        + " rowtime timestamp(3),\n"
                        + " proctime as PROCTIME(),\n"
                        + " watermark for rowtime as rowtime"
                        + ") with (\n"
                        + " 'connector' = 'values',\n"
                        + " 'bounded' = 'false')";
        tEnv.executeSql(srcTableDdl);
        tEnv.createTemporarySystemFunction("pyFunc", new PandasAggregateFunction());
    }

    @Test
    public void testProcTimeBoundedPartitionedRangeOver() {
        String sinkTableDdl =
                "CREATE TABLE MySink (\n"
                        + " a bigint,\n"
                        + " b bigint\n"
                        + ") with (\n"
                        + " 'connector' = 'values',\n"
                        + " 'sink-insert-only' = 'false',\n"
                        + " 'table-sink-class' = 'DEFAULT')";
        tEnv.executeSql(sinkTableDdl);
        String sql =
                "insert into MySink SELECT a,\n"
                        + " pyFunc(c, c) OVER (PARTITION BY a ORDER BY proctime\n"
                        + " RANGE BETWEEN INTERVAL '2' HOUR PRECEDING AND CURRENT ROW)\n"
                        + "FROM MyTable";
        util.verifyJsonPlan(sql);
    }

    @Test
    public void testProcTimeBoundedNonPartitionedRangeOver() {
        String sinkTableDdl =
                "CREATE TABLE MySink (\n"
                        + " a bigint,\n"
                        + " b bigint\n"
                        + ") with (\n"
                        + " 'connector' = 'values',\n"
                        + " 'sink-insert-only' = 'false',\n"
                        + " 'table-sink-class' = 'DEFAULT')";
        tEnv.executeSql(sinkTableDdl);
        String sql =
                "insert into MySink SELECT a,\n"
                        + " pyFunc(c, c) OVER (ORDER BY proctime\n"
                        + " RANGE BETWEEN INTERVAL '10' SECOND PRECEDING AND CURRENT ROW)\n"
                        + " FROM MyTable";
        util.verifyJsonPlan(sql);
    }

    @Test
    public void testProcTimeUnboundedPartitionedRangeOver() {
        String sinkTableDdl =
                "CREATE TABLE MySink (\n"
                        + " a bigint,\n"
                        + " b bigint\n"
                        + ") with (\n"
                        + " 'connector' = 'values',\n"
                        + " 'sink-insert-only' = 'false',\n"
                        + " 'table-sink-class' = 'DEFAULT')";
        tEnv.executeSql(sinkTableDdl);
        String sql =
                "insert into MySink SELECT a,\n"
                        + " pyFunc(c, c) OVER (PARTITION BY a ORDER BY proctime RANGE UNBOUNDED PRECEDING)\n"
                        + "FROM MyTable";
        util.verifyJsonPlan(sql);
    }

    @Test
    public void testRowTimeBoundedPartitionedRowsOver() {
        String sinkTableDdl =
                "CREATE TABLE MySink (\n"
                        + " a bigint,\n"
                        + " b bigint\n"
                        + ") with (\n"
                        + " 'connector' = 'values',\n"
                        + " 'sink-insert-only' = 'false',\n"
                        + " 'table-sink-class' = 'DEFAULT')";
        tEnv.executeSql(sinkTableDdl);
        String sql =
                "insert into MySink SELECT a,\n"
                        + " pyFunc(c, c) OVER (PARTITION BY a ORDER BY rowtime\n"
                        + " ROWS BETWEEN 5 preceding AND CURRENT ROW)\n"
                        + "FROM MyTable";
        util.verifyJsonPlan(sql);
    }

    @Test
    public void testProcTimeBoundedPartitionedRowsOverWithBuiltinProctime() {
        String sinkTableDdl =
                "CREATE TABLE MySink (\n"
                        + " a bigint,\n"
                        + " b bigint\n"
                        + ") with (\n"
                        + " 'connector' = 'values',\n"
                        + " 'sink-insert-only' = 'false',\n"
                        + " 'table-sink-class' = 'DEFAULT')";
        tEnv.executeSql(sinkTableDdl);
        String sql =
                "insert into MySink SELECT a, "
                        + " pyFunc(c, c) OVER ("
                        + " PARTITION BY a ORDER BY proctime() ROWS BETWEEN 4 PRECEDING AND CURRENT ROW) "
                        + "FROM MyTable";
        util.verifyJsonPlan(sql);
    }
}
