From ef46460312b34284428398266db6f4cedf03052e Mon Sep 17 00:00:00 2001 From: Wei Zhong Date: Fri, 17 Jul 2020 16:58:39 +0800 Subject: [PATCH] [FLINK-18463][python] Make the "input_types" parameter of the Python UDF/UDTF decorator optional. This closes #12921. --- docs/dev/table/python/python_udfs.md | 21 ++-- docs/dev/table/python/python_udfs.zh.md | 21 ++-- .../table/python/vectorized_python_udfs.md | 2 +- .../table/python/vectorized_python_udfs.zh.md | 2 +- .../pyflink/table/table_environment.py | 21 ++-- .../pyflink/table/tests/test_dependency.py | 23 ++-- .../pyflink/table/tests/test_pandas_udf.py | 59 ++++----- flink-python/pyflink/table/tests/test_udf.py | 113 ++++++++---------- flink-python/pyflink/table/tests/test_udtf.py | 10 +- flink-python/pyflink/table/udf.py | 41 ++++--- .../python/PythonScalarFunction.java | 6 +- .../functions/python/PythonTableFunction.java | 6 +- 12 files changed, 150 insertions(+), 175 deletions(-) diff --git a/docs/dev/table/python/python_udfs.md b/docs/dev/table/python/python_udfs.md index 98ce5f064f2af..02ed376d33672 100644 --- a/docs/dev/table/python/python_udfs.md +++ b/docs/dev/table/python/python_udfs.md @@ -52,7 +52,7 @@ table_env = BatchTableEnvironment.create(env) table_env.get_config().get_configuration().set_string("taskmanager.memory.task.off-heap.size", '80m') # register the Python function -table_env.register_function("hash_code", udf(HashCode(), DataTypes.BIGINT(), DataTypes.BIGINT())) +table_env.register_function("hash_code", udf(HashCode(), result_type=DataTypes.BIGINT())) # use the Python function in Python Table API my_table.select("string, bigint, bigint.hash_code(), hash_code(bigint)") @@ -110,29 +110,28 @@ class Add(ScalarFunction): def eval(self, i, j): return i + j -add = udf(Add(), [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT()) +add = udf(Add(), result_type=DataTypes.BIGINT()) # option 2: Python function -@udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], result_type=DataTypes.BIGINT()) +@udf(result_type=DataTypes.BIGINT()) def add(i, j): return i + j # option 3: lambda function -add = udf(lambda i, j: i + j, [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT()) +add = udf(lambda i, j: i + j, result_type=DataTypes.BIGINT()) # option 4: callable function class CallableAdd(object): def __call__(self, i, j): return i + j -add = udf(CallableAdd(), [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT()) +add = udf(CallableAdd(), result_type=DataTypes.BIGINT()) # option 5: partial function def partial_add(i, j, k): return i + j + k -add = udf(functools.partial(partial_add, k=1), [DataTypes.BIGINT(), DataTypes.BIGINT()], - DataTypes.BIGINT()) +add = udf(functools.partial(partial_add, k=1), result_type=DataTypes.BIGINT()) # register the Python function table_env.register_function("add", add) @@ -163,7 +162,7 @@ my_table = ... # type: Table, table schema: [a: String] table_env.get_config().get_configuration().set_string("taskmanager.memory.task.off-heap.size", '80m') # register the Python Table Function -table_env.register_function("split", udtf(Split(), DataTypes.STRING(), [DataTypes.STRING(), DataTypes.INT()])) +table_env.register_function("split", udtf(Split(), result_types=[DataTypes.STRING(), DataTypes.INT()])) # use the Python Table Function in Python Table API my_table.join_lateral("split(a) as (word, length)") @@ -231,18 +230,18 @@ Like Python scalar functions, you can use the above five ways to define Python T {% highlight python %} # option 1: generator function -@udtf(input_types=DataTypes.BIGINT(), result_types=DataTypes.BIGINT()) +@udtf(result_types=DataTypes.BIGINT()) def generator_func(x): yield 1 yield 2 # option 2: return iterator -@udtf(input_types=DataTypes.BIGINT(), result_types=DataTypes.BIGINT()) +@udtf(result_types=DataTypes.BIGINT()) def iterator_func(x): return range(5) # option 3: return iterable -@udtf(input_types=DataTypes.BIGINT(), result_types=DataTypes.BIGINT()) +@udtf(result_types=DataTypes.BIGINT()) def iterable_func(x): result = [1, 2, 3] return result diff --git a/docs/dev/table/python/python_udfs.zh.md b/docs/dev/table/python/python_udfs.zh.md index 1d83da55134ea..c1618fb095726 100644 --- a/docs/dev/table/python/python_udfs.zh.md +++ b/docs/dev/table/python/python_udfs.zh.md @@ -52,7 +52,7 @@ table_env = BatchTableEnvironment.create(env) table_env.get_config().get_configuration().set_string("taskmanager.memory.task.off-heap.size", '80m') # register the Python function -table_env.register_function("hash_code", udf(HashCode(), DataTypes.BIGINT(), DataTypes.BIGINT())) +table_env.register_function("hash_code", udf(HashCode(), result_type=DataTypes.BIGINT())) # use the Python function in Python Table API my_table.select("string, bigint, bigint.hash_code(), hash_code(bigint)") @@ -110,29 +110,28 @@ class Add(ScalarFunction): def eval(self, i, j): return i + j -add = udf(Add(), [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT()) +add = udf(Add(), result_type=DataTypes.BIGINT()) # option 2: Python function -@udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], result_type=DataTypes.BIGINT()) +@udf(result_type=DataTypes.BIGINT()) def add(i, j): return i + j # option 3: lambda function -add = udf(lambda i, j: i + j, [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT()) +add = udf(lambda i, j: i + j, result_type=DataTypes.BIGINT()) # option 4: callable function class CallableAdd(object): def __call__(self, i, j): return i + j -add = udf(CallableAdd(), [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT()) +add = udf(CallableAdd(), result_type=DataTypes.BIGINT()) # option 5: partial function def partial_add(i, j, k): return i + j + k -add = udf(functools.partial(partial_add, k=1), [DataTypes.BIGINT(), DataTypes.BIGINT()], - DataTypes.BIGINT()) +add = udf(functools.partial(partial_add, k=1), result_type=DataTypes.BIGINT()) # register the Python function table_env.register_function("add", add) @@ -163,7 +162,7 @@ my_table = ... # type: Table, table schema: [a: String] table_env.get_config().get_configuration().set_string("taskmanager.memory.task.off-heap.size", '80m') # register the Python Table Function -table_env.register_function("split", udtf(Split(), DataTypes.STRING(), [DataTypes.STRING(), DataTypes.INT()])) +table_env.register_function("split", udtf(Split(), result_types=[DataTypes.STRING(), DataTypes.INT()])) # use the Python Table Function in Python Table API my_table.join_lateral("split(a) as (word, length)") @@ -231,18 +230,18 @@ Like Python scalar functions, you can use the above five ways to define Python T {% highlight python %} # option 1: generator function -@udtf(input_types=DataTypes.BIGINT(), result_types=DataTypes.BIGINT()) +@udtf(result_types=DataTypes.BIGINT()) def generator_func(x): yield 1 yield 2 # option 2: return iterator -@udtf(input_types=DataTypes.BIGINT(), result_types=DataTypes.BIGINT()) +@udtf(result_types=DataTypes.BIGINT()) def iterator_func(x): return range(5) # option 3: return iterable -@udtf(input_types=DataTypes.BIGINT(), result_types=DataTypes.BIGINT()) +@udtf(result_types=DataTypes.BIGINT()) def iterable_func(x): result = [1, 2, 3] return result diff --git a/docs/dev/table/python/vectorized_python_udfs.md b/docs/dev/table/python/vectorized_python_udfs.md index ee5f03f937c1e..d4674386bf848 100644 --- a/docs/dev/table/python/vectorized_python_udfs.md +++ b/docs/dev/table/python/vectorized_python_udfs.md @@ -48,7 +48,7 @@ The following example shows how to define your own vectorized Python scalar func and use it in a query: {% highlight python %} -@udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], result_type=DataTypes.BIGINT(), udf_type="pandas") +@udf(result_type=DataTypes.BIGINT(), udf_type="pandas") def add(i, j): return i + j diff --git a/docs/dev/table/python/vectorized_python_udfs.zh.md b/docs/dev/table/python/vectorized_python_udfs.zh.md index b2d2ed922f958..c6381c1225086 100644 --- a/docs/dev/table/python/vectorized_python_udfs.zh.md +++ b/docs/dev/table/python/vectorized_python_udfs.zh.md @@ -48,7 +48,7 @@ The following example shows how to define your own vectorized Python scalar func and use it in a query: {% highlight python %} -@udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], result_type=DataTypes.BIGINT(), udf_type="pandas") +@udf(result_type=DataTypes.BIGINT(), udf_type="pandas") def add(i, j): return i + j diff --git a/flink-python/pyflink/table/table_environment.py b/flink-python/pyflink/table/table_environment.py index fbdead069acc0..aa05b13f0dce4 100644 --- a/flink-python/pyflink/table/table_environment.py +++ b/flink-python/pyflink/table/table_environment.py @@ -221,10 +221,9 @@ def create_temporary_system_function(self, name: str, :: >>> table_env.create_temporary_system_function( - ... "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())) + ... "add_one", udf(lambda i: i + 1, result_type=DataTypes.BIGINT())) - >>> @udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], - ... result_type=DataTypes.BIGINT()) + >>> @udf(result_type=DataTypes.BIGINT()) ... def add(i, j): ... return i + j >>> table_env.create_temporary_system_function("add", add) @@ -233,7 +232,7 @@ def create_temporary_system_function(self, name: str, ... def eval(self, i): ... return i - 1 >>> table_env.create_temporary_system_function( - ... "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())) + ... "subtract_one", udf(SubtractOne(), result_type=DataTypes.BIGINT())) :param name: The name under which the function will be registered globally. :param function: The function class containing the implementation. The function must have a @@ -356,10 +355,9 @@ def create_temporary_function(self, path: str, function: UserDefinedFunctionWrap :: >>> table_env.create_temporary_function( - ... "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())) + ... "add_one", udf(lambda i: i + 1, result_type=DataTypes.BIGINT())) - >>> @udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], - ... result_type=DataTypes.BIGINT()) + >>> @udf(result_type=DataTypes.BIGINT()) ... def add(i, j): ... return i + j >>> table_env.create_temporary_function("add", add) @@ -368,7 +366,7 @@ def create_temporary_function(self, path: str, function: UserDefinedFunctionWrap ... def eval(self, i): ... return i - 1 >>> table_env.create_temporary_function( - ... "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())) + ... "subtract_one", udf(SubtractOne(), result_type=DataTypes.BIGINT())) :param path: The path under which the function will be registered. See also the :class:`~pyflink.table.TableEnvironment` class description for @@ -1101,10 +1099,9 @@ def register_function(self, name, function): :: >>> table_env.register_function( - ... "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())) + ... "add_one", udf(lambda i: i + 1, result_type=DataTypes.BIGINT())) - >>> @udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], - ... result_type=DataTypes.BIGINT()) + >>> @udf(result_type=DataTypes.BIGINT()) ... def add(i, j): ... return i + j >>> table_env.register_function("add", add) @@ -1113,7 +1110,7 @@ def register_function(self, name, function): ... def eval(self, i): ... return i - 1 >>> table_env.register_function( - ... "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())) + ... "subtract_one", udf(SubtractOne(), result_type=DataTypes.BIGINT())) :param name: The name under which the function is registered. :type name: str diff --git a/flink-python/pyflink/table/tests/test_dependency.py b/flink-python/pyflink/table/tests/test_dependency.py index 4b63784b5226a..2e5abb08f4f18 100644 --- a/flink-python/pyflink/table/tests/test_dependency.py +++ b/flink-python/pyflink/table/tests/test_dependency.py @@ -45,8 +45,7 @@ def plus_two(i): from test_dependency_manage_lib import add_two return add_two(i) - self.t_env.register_function("add_two", udf(plus_two, DataTypes.BIGINT(), - DataTypes.BIGINT())) + self.t_env.register_function("add_two", udf(plus_two, result_type=DataTypes.BIGINT())) table_sink = source_sink_utils.TestAppendSink( ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()]) self.t_env.register_table_sink("Results", table_sink) @@ -77,8 +76,7 @@ def plus_two(i): from test_dependency_manage_lib import add_two return add_two(i) - self.t_env.register_function("add_two", udf(plus_two, DataTypes.BIGINT(), - DataTypes.BIGINT())) + self.t_env.register_function("add_two", udf(plus_two, result_type=DataTypes.BIGINT())) t = self.t_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b'])\ .select("add_two(a), a") @@ -107,8 +105,7 @@ def check_requirements(i): return i self.t_env.register_function("check_requirements", - udf(check_requirements, DataTypes.BIGINT(), - DataTypes.BIGINT())) + udf(check_requirements, result_type=DataTypes.BIGINT())) table_sink = source_sink_utils.TestAppendSink( ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()]) self.t_env.register_table_sink("Results", table_sink) @@ -153,9 +150,7 @@ def add_one(i): from python_package1 import plus return plus(i, 1) - self.t_env.register_function("add_one", - udf(add_one, DataTypes.BIGINT(), - DataTypes.BIGINT())) + self.t_env.register_function("add_one", udf(add_one, result_type=DataTypes.BIGINT())) table_sink = source_sink_utils.TestAppendSink( ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()]) self.t_env.register_table_sink("Results", table_sink) @@ -181,8 +176,7 @@ def add_from_file(i): return i + int(f.read()) self.t_env.register_function("add_from_file", - udf(add_from_file, DataTypes.BIGINT(), - DataTypes.BIGINT())) + udf(add_from_file, result_type=DataTypes.BIGINT())) table_sink = source_sink_utils.TestAppendSink( ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()]) self.t_env.register_table_sink("Results", table_sink) @@ -207,8 +201,7 @@ def check_python_exec(i): return i self.t_env.register_function("check_python_exec", - udf(check_python_exec, DataTypes.BIGINT(), - DataTypes.BIGINT())) + udf(check_python_exec, result_type=DataTypes.BIGINT())) def check_pyflink_gateway_disabled(i): try: @@ -222,8 +215,8 @@ def check_pyflink_gateway_disabled(i): return i self.t_env.register_function("check_pyflink_gateway_disabled", - udf(check_pyflink_gateway_disabled, DataTypes.BIGINT(), - DataTypes.BIGINT())) + udf(check_pyflink_gateway_disabled, + result_type=DataTypes.BIGINT())) table_sink = source_sink_utils.TestAppendSink( ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()]) diff --git a/flink-python/pyflink/table/tests/test_pandas_udf.py b/flink-python/pyflink/table/tests/test_pandas_udf.py index 03249ae490d5f..66941aa69eed6 100644 --- a/flink-python/pyflink/table/tests/test_pandas_udf.py +++ b/flink-python/pyflink/table/tests/test_pandas_udf.py @@ -34,7 +34,7 @@ class PandasUDFTests(unittest.TestCase): def test_non_exist_udf_type(self): with self.assertRaisesRegex(ValueError, 'The udf_type must be one of \'general, pandas\''): - udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT(), udf_type="non-exist") + udf(lambda i: i + 1, result_type=DataTypes.BIGINT(), udf_type="non-exist") class PandasUDFITTests(object): @@ -43,13 +43,13 @@ def test_basic_functionality(self): # pandas UDF self.t_env.register_function( "add_one", - udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT(), udf_type="pandas")) + udf(lambda i: i + 1, result_type=DataTypes.BIGINT(), udf_type="pandas")) self.t_env.register_function("add", add) # general Python UDF self.t_env.register_function( - "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())) + "subtract_one", udf(SubtractOne(), result_type=DataTypes.BIGINT())) table_sink = source_sink_utils.TestAppendSink( ['a', 'b', 'c', 'd'], @@ -173,77 +173,72 @@ def row_func(row_param): self.t_env.register_function( "tinyint_func", - udf(tinyint_func, [DataTypes.TINYINT()], DataTypes.TINYINT(), udf_type="pandas")) + udf(tinyint_func, result_type=DataTypes.TINYINT(), udf_type="pandas")) self.t_env.register_function( "smallint_func", - udf(smallint_func, [DataTypes.SMALLINT()], DataTypes.SMALLINT(), udf_type="pandas")) + udf(smallint_func, result_type=DataTypes.SMALLINT(), udf_type="pandas")) self.t_env.register_function( "int_func", - udf(int_func, [DataTypes.INT()], DataTypes.INT(), udf_type="pandas")) + udf(int_func, result_type=DataTypes.INT(), udf_type="pandas")) self.t_env.register_function( "bigint_func", - udf(bigint_func, [DataTypes.BIGINT()], DataTypes.BIGINT(), udf_type="pandas")) + udf(bigint_func, result_type=DataTypes.BIGINT(), udf_type="pandas")) self.t_env.register_function( "boolean_func", - udf(boolean_func, [DataTypes.BOOLEAN()], DataTypes.BOOLEAN(), udf_type="pandas")) + udf(boolean_func, result_type=DataTypes.BOOLEAN(), udf_type="pandas")) self.t_env.register_function( "float_func", - udf(float_func, [DataTypes.FLOAT()], DataTypes.FLOAT(), udf_type="pandas")) + udf(float_func, result_type=DataTypes.FLOAT(), udf_type="pandas")) self.t_env.register_function( "double_func", - udf(double_func, [DataTypes.DOUBLE()], DataTypes.DOUBLE(), udf_type="pandas")) + udf(double_func, result_type=DataTypes.DOUBLE(), udf_type="pandas")) self.t_env.register_function( "varchar_func", - udf(varchar_func, [DataTypes.STRING()], DataTypes.STRING(), udf_type="pandas")) + udf(varchar_func, result_type=DataTypes.STRING(), udf_type="pandas")) self.t_env.register_function( "varbinary_func", - udf(varbinary_func, [DataTypes.BYTES()], DataTypes.BYTES(), udf_type="pandas")) + udf(varbinary_func, result_type=DataTypes.BYTES(), udf_type="pandas")) self.t_env.register_function( "decimal_func", - udf(decimal_func, [DataTypes.DECIMAL(38, 18)], DataTypes.DECIMAL(38, 18), - udf_type="pandas")) + udf(decimal_func, result_type=DataTypes.DECIMAL(38, 18), udf_type="pandas")) self.t_env.register_function( "date_func", - udf(date_func, [DataTypes.DATE()], DataTypes.DATE(), udf_type="pandas")) + udf(date_func, result_type=DataTypes.DATE(), udf_type="pandas")) self.t_env.register_function( "time_func", - udf(time_func, [DataTypes.TIME()], DataTypes.TIME(), udf_type="pandas")) + udf(time_func, result_type=DataTypes.TIME(), udf_type="pandas")) self.t_env.register_function( "timestamp_func", - udf(timestamp_func, [DataTypes.TIMESTAMP(3)], DataTypes.TIMESTAMP(3), - udf_type="pandas")) + udf(timestamp_func, result_type=DataTypes.TIMESTAMP(3), udf_type="pandas")) self.t_env.register_function( "array_str_func", - udf(array_func, [DataTypes.ARRAY(DataTypes.STRING())], - DataTypes.ARRAY(DataTypes.STRING()), udf_type="pandas")) + udf(array_func, result_type=DataTypes.ARRAY(DataTypes.STRING()), udf_type="pandas")) self.t_env.register_function( "array_timestamp_func", - udf(array_func, [DataTypes.ARRAY(DataTypes.TIMESTAMP(3))], - DataTypes.ARRAY(DataTypes.TIMESTAMP(3)), udf_type="pandas")) + udf(array_func, result_type=DataTypes.ARRAY(DataTypes.TIMESTAMP(3)), udf_type="pandas")) self.t_env.register_function( "array_int_func", - udf(array_func, [DataTypes.ARRAY(DataTypes.INT())], - DataTypes.ARRAY(DataTypes.INT()), udf_type="pandas")) + udf(array_func, result_type=DataTypes.ARRAY(DataTypes.INT()), udf_type="pandas")) self.t_env.register_function( "nested_array_func", - udf(nested_array_func, [DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.STRING()))], - DataTypes.ARRAY(DataTypes.STRING()), udf_type="pandas")) + udf(nested_array_func, + result_type=DataTypes.ARRAY(DataTypes.STRING()), udf_type="pandas")) row_type = DataTypes.ROW( [DataTypes.FIELD("f1", DataTypes.INT()), @@ -252,7 +247,7 @@ def row_func(row_param): DataTypes.FIELD("f4", DataTypes.ARRAY(DataTypes.INT()))]) self.t_env.register_function( "row_func", - udf(row_func, [row_type], row_type, udf_type="pandas")) + udf(row_func, result_type=row_type, udf_type="pandas")) table_sink = source_sink_utils.TestAppendSink( ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', @@ -350,8 +345,7 @@ def local_zoned_timestamp_func(local_zoned_timestamp_param): self.t_env.register_function( "local_zoned_timestamp_func", udf(local_zoned_timestamp_func, - [DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3)], - DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3), + result_type=DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3), udf_type="pandas")) table_sink = source_sink_utils.TestAppendSink( @@ -379,13 +373,13 @@ class BatchPandasUDFITTests(PyFlinkBatchTableTestCase): def test_basic_functionality(self): self.t_env.register_function( "add_one", - udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT(), udf_type="pandas")) + udf(lambda i: i + 1, result_type=DataTypes.BIGINT(), udf_type="pandas")) self.t_env.register_function("add", add) # general Python UDF self.t_env.register_function( - "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())) + "subtract_one", udf(SubtractOne(), result_type=DataTypes.BIGINT())) t = self.t_env.from_elements([(1, 2, 3), (2, 5, 6), (3, 1, 9)], ['a', 'b', 'c']) t = t.where("add_one(b) <= 3") \ @@ -406,8 +400,7 @@ class BlinkStreamPandasUDFITTests(PandasUDFITTests, pass -@udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], result_type=DataTypes.BIGINT(), - udf_type='pandas') +@udf(result_type=DataTypes.BIGINT(), udf_type='pandas') def add(i, j): return i + j diff --git a/flink-python/pyflink/table/tests/test_udf.py b/flink-python/pyflink/table/tests/test_udf.py index fbe0bd8716076..c9ca376581807 100644 --- a/flink-python/pyflink/table/tests/test_udf.py +++ b/flink-python/pyflink/table/tests/test_udf.py @@ -35,18 +35,18 @@ def test_scalar_function(self): self.t_env.get_config().get_configuration().set_string('python.metric.enabled', 'false') # test lambda function self.t_env.register_function( - "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())) + "add_one", udf(lambda i: i + 1, result_type=DataTypes.BIGINT())) # test Python ScalarFunction self.t_env.register_function( - "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())) + "subtract_one", udf(SubtractOne(), result_type=DataTypes.BIGINT())) # test Python function self.t_env.register_function("add", add) # test callable function self.t_env.register_function( - "add_one_callable", udf(CallablePlus(), DataTypes.BIGINT(), DataTypes.BIGINT())) + "add_one_callable", udf(CallablePlus(), result_type=DataTypes.BIGINT())) def partial_func(col, param): return col + param @@ -55,7 +55,7 @@ def partial_func(col, param): import functools self.t_env.register_function( "add_one_partial", - udf(functools.partial(partial_func, param=1), DataTypes.BIGINT(), DataTypes.BIGINT())) + udf(functools.partial(partial_func, param=1), result_type=DataTypes.BIGINT())) table_sink = source_sink_utils.TestAppendSink( ['a', 'b', 'c', 'd', 'e', 'f'], @@ -74,9 +74,9 @@ def partial_func(col, param): def test_chaining_scalar_function(self): self.t_env.register_function( - "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())) + "add_one", udf(lambda i: i + 1, result_type=DataTypes.BIGINT())) self.t_env.register_function( - "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())) + "subtract_one", udf(SubtractOne(), result_type=DataTypes.BIGINT())) self.t_env.register_function("add", add) table_sink = source_sink_utils.TestAppendSink( @@ -95,7 +95,7 @@ def test_udf_in_join_condition(self): t1 = self.t_env.from_elements([(2, "Hi")], ['a', 'b']) t2 = self.t_env.from_elements([(2, "Flink")], ['c', 'd']) - self.t_env.register_function("f", udf(lambda i: i, DataTypes.BIGINT(), DataTypes.BIGINT())) + self.t_env.register_function("f", udf(lambda i: i, result_type=DataTypes.BIGINT())) table_sink = source_sink_utils.TestAppendSink( ['a', 'b', 'c', 'd'], @@ -111,7 +111,7 @@ def test_udf_in_join_condition_2(self): t1 = self.t_env.from_elements([(1, "Hi"), (2, "Hi")], ['a', 'b']) t2 = self.t_env.from_elements([(2, "Flink")], ['c', 'd']) - self.t_env.register_function("f", udf(lambda i: i, DataTypes.BIGINT(), DataTypes.BIGINT())) + self.t_env.register_function("f", udf(lambda i: i, result_type=DataTypes.BIGINT())) table_sink = source_sink_utils.TestAppendSink( ['a', 'b', 'c', 'd'], @@ -177,26 +177,11 @@ def udf_with_constant_params(p, null_param, tinyint_param, smallint_param, int_p self.t_env.register_function("udf_with_constant_params", udf(udf_with_constant_params, - input_types=[DataTypes.BIGINT(), - DataTypes.BIGINT(), - DataTypes.TINYINT(), - DataTypes.SMALLINT(), - DataTypes.INT(), - DataTypes.BIGINT(), - DataTypes.DECIMAL(38, 18), - DataTypes.FLOAT(), - DataTypes.DOUBLE(), - DataTypes.BOOLEAN(), - DataTypes.STRING(), - DataTypes.DATE(), - DataTypes.TIME(), - DataTypes.TIMESTAMP(3)], result_type=DataTypes.BIGINT())) self.t_env.register_function( "udf_with_all_constant_params", udf(lambda i, j: i + j, - [DataTypes.BIGINT(), DataTypes.BIGINT()], - DataTypes.BIGINT())) + result_type=DataTypes.BIGINT())) table_sink = source_sink_utils.TestAppendSink(['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()]) @@ -229,7 +214,7 @@ def udf_with_constant_params(p, null_param, tinyint_param, smallint_param, int_p def test_overwrite_builtin_function(self): self.t_env.register_function( "plus", udf(lambda i, j: i + j - 1, - [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT())) + result_type=DataTypes.BIGINT())) table_sink = source_sink_utils.TestAppendSink(['a'], [DataTypes.BIGINT()]) self.t_env.register_table_sink("Results", table_sink) @@ -243,7 +228,7 @@ def test_overwrite_builtin_function(self): def test_open(self): self.t_env.get_config().get_configuration().set_string('python.metric.enabled', 'true') self.t_env.register_function( - "subtract", udf(Subtract(), DataTypes.BIGINT(), DataTypes.BIGINT())) + "subtract", udf(Subtract(), result_type=DataTypes.BIGINT())) table_sink = source_sink_utils.TestAppendSink( ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()]) self.t_env.register_table_sink("Results", table_sink) @@ -256,9 +241,9 @@ def test_open(self): def test_udf_without_arguments(self): self.t_env.register_function("one", udf( - lambda: 1, input_types=[], result_type=DataTypes.BIGINT(), deterministic=True)) + lambda: 1, result_type=DataTypes.BIGINT(), deterministic=True)) self.t_env.register_function("two", udf( - lambda: 2, input_types=[], result_type=DataTypes.BIGINT(), deterministic=False)) + lambda: 2, result_type=DataTypes.BIGINT(), deterministic=False)) table_sink = source_sink_utils.TestAppendSink(['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()]) @@ -363,59 +348,56 @@ def decimal_cut_func(decimal_param): return decimal_param self.t_env.register_function( - "boolean_func", udf(boolean_func, [DataTypes.BOOLEAN()], DataTypes.BOOLEAN())) + "boolean_func", udf(boolean_func, result_type=DataTypes.BOOLEAN())) self.t_env.register_function( - "tinyint_func", udf(tinyint_func, [DataTypes.TINYINT()], DataTypes.TINYINT())) + "tinyint_func", udf(tinyint_func, result_type=DataTypes.TINYINT())) self.t_env.register_function( - "smallint_func", udf(smallint_func, [DataTypes.SMALLINT()], DataTypes.SMALLINT())) + "smallint_func", udf(smallint_func, result_type=DataTypes.SMALLINT())) self.t_env.register_function( - "int_func", udf(int_func, [DataTypes.INT()], DataTypes.INT())) + "int_func", udf(int_func, result_type=DataTypes.INT())) self.t_env.register_function( - "bigint_func", udf(bigint_func, [DataTypes.BIGINT()], DataTypes.BIGINT())) + "bigint_func", udf(bigint_func, result_type=DataTypes.BIGINT())) self.t_env.register_function( - "bigint_func_none", udf(bigint_func_none, [DataTypes.BIGINT()], DataTypes.BIGINT())) + "bigint_func_none", udf(bigint_func_none, result_type=DataTypes.BIGINT())) self.t_env.register_function( - "float_func", udf(float_func, [DataTypes.FLOAT()], DataTypes.FLOAT())) + "float_func", udf(float_func, result_type=DataTypes.FLOAT())) self.t_env.register_function( - "double_func", udf(double_func, [DataTypes.DOUBLE()], DataTypes.DOUBLE())) + "double_func", udf(double_func, result_type=DataTypes.DOUBLE())) self.t_env.register_function( - "bytes_func", udf(bytes_func, [DataTypes.BYTES()], DataTypes.BYTES())) + "bytes_func", udf(bytes_func, result_type=DataTypes.BYTES())) self.t_env.register_function( - "str_func", udf(str_func, [DataTypes.STRING()], DataTypes.STRING())) + "str_func", udf(str_func, result_type=DataTypes.STRING())) self.t_env.register_function( - "date_func", udf(date_func, [DataTypes.DATE()], DataTypes.DATE())) + "date_func", udf(date_func, result_type=DataTypes.DATE())) self.t_env.register_function( - "time_func", udf(time_func, [DataTypes.TIME()], DataTypes.TIME())) + "time_func", udf(time_func, result_type=DataTypes.TIME())) self.t_env.register_function( - "timestamp_func", udf(timestamp_func, [DataTypes.TIMESTAMP(3)], DataTypes.TIMESTAMP(3))) + "timestamp_func", udf(timestamp_func, result_type=DataTypes.TIMESTAMP(3))) self.t_env.register_function( - "array_func", udf(array_func, [DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.BIGINT()))], - DataTypes.ARRAY(DataTypes.BIGINT()))) + "array_func", udf(array_func, result_type=DataTypes.ARRAY(DataTypes.BIGINT()))) self.t_env.register_function( - "map_func", udf(map_func, [DataTypes.MAP(DataTypes.BIGINT(), DataTypes.STRING())], - DataTypes.MAP(DataTypes.BIGINT(), DataTypes.STRING()))) + "map_func", udf(map_func, + result_type=DataTypes.MAP(DataTypes.BIGINT(), DataTypes.STRING()))) self.t_env.register_function( - "decimal_func", udf(decimal_func, [DataTypes.DECIMAL(38, 18)], - DataTypes.DECIMAL(38, 18))) + "decimal_func", udf(decimal_func, result_type=DataTypes.DECIMAL(38, 18))) self.t_env.register_function( - "decimal_cut_func", udf(decimal_cut_func, [DataTypes.DECIMAL(38, 18)], - DataTypes.DECIMAL(38, 18))) + "decimal_cut_func", udf(decimal_cut_func, result_type=DataTypes.DECIMAL(38, 18))) table_sink = source_sink_utils.TestAppendSink( ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q'], @@ -480,9 +462,9 @@ def test_create_and_drop_function(self): t_env = self.t_env t_env.create_temporary_system_function( - "add_one_func", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())) + "add_one_func", udf(lambda i: i + 1, result_type=DataTypes.BIGINT())) t_env.create_temporary_function( - "subtract_one_func", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())) + "subtract_one_func", udf(SubtractOne(), result_type=DataTypes.BIGINT())) self.assert_equals(t_env.list_user_defined_functions(), ['add_one_func', 'subtract_one_func']) @@ -505,9 +487,9 @@ class PyFlinkBatchUserDefinedFunctionTests(PyFlinkBatchTableTestCase): def test_chaining_scalar_function(self): self.t_env.register_function( - "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())) + "add_one", udf(lambda i: i + 1, result_type=DataTypes.BIGINT())) self.t_env.register_function( - "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())) + "subtract_one", udf(SubtractOne(), result_type=DataTypes.BIGINT())) self.t_env.register_function("add", add) t = self.t_env.from_elements([(1, 2, 1), (2, 5, 2), (3, 1, 3)], ['a', 'b', 'c'])\ @@ -520,43 +502,42 @@ def test_chaining_scalar_function(self): class PyFlinkBlinkStreamUserDefinedFunctionTests(UserDefinedFunctionTests, PyFlinkBlinkStreamTableTestCase): def test_deterministic(self): - add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT()) + add_one = udf(lambda i: i + 1, result_type=DataTypes.BIGINT()) self.assertTrue(add_one._deterministic) - add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT(), deterministic=False) + add_one = udf(lambda i: i + 1, result_type=DataTypes.BIGINT(), deterministic=False) self.assertFalse(add_one._deterministic) - subtract_one = udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT()) + subtract_one = udf(SubtractOne(), result_type=DataTypes.BIGINT()) self.assertTrue(subtract_one._deterministic) with self.assertRaises(ValueError, msg="Inconsistent deterministic: False and True"): - udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT(), deterministic=False) + udf(SubtractOne(), result_type=DataTypes.BIGINT(), deterministic=False) self.assertTrue(add._deterministic) - @udf(input_types=DataTypes.BIGINT(), result_type=DataTypes.BIGINT(), deterministic=False) + @udf(result_type=DataTypes.BIGINT(), deterministic=False) def non_deterministic_udf(i): return i self.assertFalse(non_deterministic_udf._deterministic) def test_name(self): - add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT()) + add_one = udf(lambda i: i + 1, result_type=DataTypes.BIGINT()) self.assertEqual("", add_one._name) - add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT(), name="add_one") + add_one = udf(lambda i: i + 1, result_type=DataTypes.BIGINT(), name="add_one") self.assertEqual("add_one", add_one._name) - subtract_one = udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT()) + subtract_one = udf(SubtractOne(), result_type=DataTypes.BIGINT()) self.assertEqual("SubtractOne", subtract_one._name) - subtract_one = udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT(), - name="subtract_one") + subtract_one = udf(SubtractOne(), result_type=DataTypes.BIGINT(), name="subtract_one") self.assertEqual("subtract_one", subtract_one._name) self.assertEqual("add", add._name) - @udf(input_types=DataTypes.BIGINT(), result_type=DataTypes.BIGINT(), name="named") + @udf(result_type=DataTypes.BIGINT(), name="named") def named_udf(i): return i @@ -597,8 +578,7 @@ def local_zoned_timestamp_func(local_zoned_timestamp_param): self.t_env.register_function( "local_zoned_timestamp_func", udf(local_zoned_timestamp_func, - [DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3)], - DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3))) + result_type=DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3))) table_sink = source_sink_utils.TestAppendSink( ['a'], [DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3)]) @@ -620,6 +600,7 @@ class PyFlinkBlinkBatchUserDefinedFunctionTests(UserDefinedFunctionTests, pass +# test specify the input_types @udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], result_type=DataTypes.BIGINT()) def add(i, j): return i + j diff --git a/flink-python/pyflink/table/tests/test_udtf.py b/flink-python/pyflink/table/tests/test_udtf.py index 76223a62fa2c5..829b891ae3106 100644 --- a/flink-python/pyflink/table/tests/test_udtf.py +++ b/flink-python/pyflink/table/tests/test_udtf.py @@ -31,14 +31,12 @@ def test_table_function(self): [DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.BIGINT()]) self.t_env.register_function( - "multi_emit", udtf(MultiEmit(), [DataTypes.BIGINT(), DataTypes.BIGINT()], - [DataTypes.BIGINT(), DataTypes.BIGINT()])) + "multi_emit", udtf(MultiEmit(), result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()])) self.t_env.register_function("condition_multi_emit", condition_multi_emit) self.t_env.register_function( - "multi_num", udf(MultiNum(), [DataTypes.BIGINT()], - DataTypes.BIGINT())) + "multi_num", udf(MultiNum(), result_type=DataTypes.BIGINT())) t = self.t_env.from_elements([(1, 1, 3), (2, 1, 6), (3, 2, 9)], ['a', 'b', 'c']) t = t.join_lateral("multi_emit(a, multi_num(b)) as (x, y)") \ @@ -55,8 +53,7 @@ def test_table_function_with_sql_query(self): [DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.BIGINT()]) self.t_env.register_function( - "multi_emit", udtf(MultiEmit(), [DataTypes.BIGINT(), DataTypes.BIGINT()], - [DataTypes.BIGINT(), DataTypes.BIGINT()])) + "multi_emit", udtf(MultiEmit(), result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()])) t = self.t_env.from_elements([(1, 1, 3), (2, 1, 6), (3, 2, 9)], ['a', 'b', 'c']) self.t_env.register_table("MyTable", t) @@ -115,6 +112,7 @@ def eval(self, x, y): yield x, i +# test specify the input_types @udtf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], result_types=DataTypes.BIGINT()) def condition_multi_emit(x, y): diff --git a/flink-python/pyflink/table/udf.py b/flink-python/pyflink/table/udf.py index b6b3e7ec9236b..3ad002ffa44bf 100644 --- a/flink-python/pyflink/table/udf.py +++ b/flink-python/pyflink/table/udf.py @@ -160,14 +160,15 @@ def __init__(self, func, input_types, deterministic=None, name=None): "Invalid function: not a function or callable (__call__ is not defined): {0}" .format(type(func))) - if not isinstance(input_types, collections.Iterable): - input_types = [input_types] + if input_types is not None: + if not isinstance(input_types, collections.Iterable): + input_types = [input_types] - for input_type in input_types: - if not isinstance(input_type, DataType): - raise TypeError( - "Invalid input_type: input_type should be DataType but contains {}".format( - input_type)) + for input_type in input_types: + if not isinstance(input_type, DataType): + raise TypeError( + "Invalid input_type: input_type should be DataType but contains {}".format( + input_type)) self._func = func self._input_types = input_types @@ -228,8 +229,11 @@ def get_python_function_kind(udf_type): import cloudpickle serialized_func = cloudpickle.dumps(func) - j_input_types = utils.to_jarray(gateway.jvm.TypeInformation, - [_to_java_type(i) for i in self._input_types]) + if self._input_types is not None: + j_input_types = utils.to_jarray( + gateway.jvm.TypeInformation, [_to_java_type(i) for i in self._input_types]) + else: + j_input_types = None j_result_type = _to_java_type(self._result_type) j_function_kind = get_python_function_kind(self._udf_type) PythonScalarFunction = gateway.jvm \ @@ -280,8 +284,11 @@ def _create_judtf(self): serialized_func = cloudpickle.dumps(func) gateway = get_gateway() - j_input_types = utils.to_jarray(gateway.jvm.TypeInformation, - [_to_java_type(i) for i in self._input_types]) + if self._input_types is not None: + j_input_types = utils.to_jarray( + gateway.jvm.TypeInformation, [_to_java_type(i) for i in self._input_types]) + else: + j_input_types = None j_result_types = utils.to_jarray(gateway.jvm.TypeInformation, [_to_java_type(i) for i in self._result_types]) @@ -327,8 +334,8 @@ def udf(f=None, input_types=None, result_type=None, deterministic=None, name=Non >>> add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT()) - >>> @udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], - ... result_type=DataTypes.BIGINT()) + >>> # The input_types is optional. + >>> @udf(result_type=DataTypes.BIGINT()) ... def add(i, j): ... return i + j @@ -339,7 +346,7 @@ def udf(f=None, input_types=None, result_type=None, deterministic=None, name=Non :param f: lambda function or user-defined function. :type f: function or UserDefinedFunction or type - :param input_types: the input data types. + :param input_types: optional, the input data types. :type input_types: list[DataType] or DataType :param result_type: the result data type. :type result_type: DataType @@ -375,8 +382,8 @@ def udtf(f=None, input_types=None, result_types=None, deterministic=None, name=N Example: :: - >>> @udtf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], - ... result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()]) + >>> # The input_types is optional. + >>> @udtf(result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()]) ... def range_emit(s, e): ... for i in range(e): ... yield s, i @@ -388,7 +395,7 @@ def udtf(f=None, input_types=None, result_types=None, deterministic=None, name=N :param f: user-defined table function. :type f: function or UserDefinedFunction or type - :param input_types: the input data types. + :param input_types: optional, the input data types. :type input_types: list[DataType] or DataType :param result_types: the result data types. :type result_types: list[DataType] or DataType diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonScalarFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonScalarFunction.java index 72df5e3b2948a..d00eb5132207d 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonScalarFunction.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonScalarFunction.java @@ -91,7 +91,11 @@ public boolean isDeterministic() { @Override public TypeInformation[] getParameterTypes(Class[] signature) { - return inputTypes; + if (inputTypes != null) { + return inputTypes; + } else { + return super.getParameterTypes(signature); + } } @Override diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableFunction.java index 3560170ba2066..2ed57aff5df68 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableFunction.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableFunction.java @@ -93,7 +93,11 @@ public boolean isDeterministic() { @Override public TypeInformation[] getParameterTypes(Class[] signature) { - return inputTypes; + if (inputTypes != null) { + return inputTypes; + } else { + return super.getParameterTypes(signature); + } } @Override