Skip to content

Commit

Permalink
[FLINK-18463][python] Make the "input_types" parameter of the Python …
Browse files Browse the repository at this point in the history
…UDF/UDTF decorator optional.

This closes apache#12921.
  • Loading branch information
WeiZhong94 authored and dianfu committed Jul 17, 2020
1 parent 9c4a984 commit ef46460
Show file tree
Hide file tree
Showing 12 changed files with 150 additions and 175 deletions.
21 changes: 10 additions & 11 deletions docs/dev/table/python/python_udfs.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ table_env = BatchTableEnvironment.create(env)
table_env.get_config().get_configuration().set_string("taskmanager.memory.task.off-heap.size", '80m')

# register the Python function
table_env.register_function("hash_code", udf(HashCode(), DataTypes.BIGINT(), DataTypes.BIGINT()))
table_env.register_function("hash_code", udf(HashCode(), result_type=DataTypes.BIGINT()))

# use the Python function in Python Table API
my_table.select("string, bigint, bigint.hash_code(), hash_code(bigint)")
Expand Down Expand Up @@ -110,29 +110,28 @@ class Add(ScalarFunction):
def eval(self, i, j):
return i + j

add = udf(Add(), [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT())
add = udf(Add(), result_type=DataTypes.BIGINT())

# option 2: Python function
@udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], result_type=DataTypes.BIGINT())
@udf(result_type=DataTypes.BIGINT())
def add(i, j):
return i + j

# option 3: lambda function
add = udf(lambda i, j: i + j, [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT())
add = udf(lambda i, j: i + j, result_type=DataTypes.BIGINT())

# option 4: callable function
class CallableAdd(object):
def __call__(self, i, j):
return i + j

add = udf(CallableAdd(), [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT())
add = udf(CallableAdd(), result_type=DataTypes.BIGINT())

# option 5: partial function
def partial_add(i, j, k):
return i + j + k

add = udf(functools.partial(partial_add, k=1), [DataTypes.BIGINT(), DataTypes.BIGINT()],
DataTypes.BIGINT())
add = udf(functools.partial(partial_add, k=1), result_type=DataTypes.BIGINT())

# register the Python function
table_env.register_function("add", add)
Expand Down Expand Up @@ -163,7 +162,7 @@ my_table = ... # type: Table, table schema: [a: String]
table_env.get_config().get_configuration().set_string("taskmanager.memory.task.off-heap.size", '80m')

# register the Python Table Function
table_env.register_function("split", udtf(Split(), DataTypes.STRING(), [DataTypes.STRING(), DataTypes.INT()]))
table_env.register_function("split", udtf(Split(), result_types=[DataTypes.STRING(), DataTypes.INT()]))

# use the Python Table Function in Python Table API
my_table.join_lateral("split(a) as (word, length)")
Expand Down Expand Up @@ -231,18 +230,18 @@ Like Python scalar functions, you can use the above five ways to define Python T

{% highlight python %}
# option 1: generator function
@udtf(input_types=DataTypes.BIGINT(), result_types=DataTypes.BIGINT())
@udtf(result_types=DataTypes.BIGINT())
def generator_func(x):
yield 1
yield 2

# option 2: return iterator
@udtf(input_types=DataTypes.BIGINT(), result_types=DataTypes.BIGINT())
@udtf(result_types=DataTypes.BIGINT())
def iterator_func(x):
return range(5)

# option 3: return iterable
@udtf(input_types=DataTypes.BIGINT(), result_types=DataTypes.BIGINT())
@udtf(result_types=DataTypes.BIGINT())
def iterable_func(x):
result = [1, 2, 3]
return result
Expand Down
21 changes: 10 additions & 11 deletions docs/dev/table/python/python_udfs.zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ table_env = BatchTableEnvironment.create(env)
table_env.get_config().get_configuration().set_string("taskmanager.memory.task.off-heap.size", '80m')

# register the Python function
table_env.register_function("hash_code", udf(HashCode(), DataTypes.BIGINT(), DataTypes.BIGINT()))
table_env.register_function("hash_code", udf(HashCode(), result_type=DataTypes.BIGINT()))

# use the Python function in Python Table API
my_table.select("string, bigint, bigint.hash_code(), hash_code(bigint)")
Expand Down Expand Up @@ -110,29 +110,28 @@ class Add(ScalarFunction):
def eval(self, i, j):
return i + j

add = udf(Add(), [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT())
add = udf(Add(), result_type=DataTypes.BIGINT())

# option 2: Python function
@udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], result_type=DataTypes.BIGINT())
@udf(result_type=DataTypes.BIGINT())
def add(i, j):
return i + j

# option 3: lambda function
add = udf(lambda i, j: i + j, [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT())
add = udf(lambda i, j: i + j, result_type=DataTypes.BIGINT())

# option 4: callable function
class CallableAdd(object):
def __call__(self, i, j):
return i + j

add = udf(CallableAdd(), [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT())
add = udf(CallableAdd(), result_type=DataTypes.BIGINT())

# option 5: partial function
def partial_add(i, j, k):
return i + j + k

add = udf(functools.partial(partial_add, k=1), [DataTypes.BIGINT(), DataTypes.BIGINT()],
DataTypes.BIGINT())
add = udf(functools.partial(partial_add, k=1), result_type=DataTypes.BIGINT())

# register the Python function
table_env.register_function("add", add)
Expand Down Expand Up @@ -163,7 +162,7 @@ my_table = ... # type: Table, table schema: [a: String]
table_env.get_config().get_configuration().set_string("taskmanager.memory.task.off-heap.size", '80m')

# register the Python Table Function
table_env.register_function("split", udtf(Split(), DataTypes.STRING(), [DataTypes.STRING(), DataTypes.INT()]))
table_env.register_function("split", udtf(Split(), result_types=[DataTypes.STRING(), DataTypes.INT()]))

# use the Python Table Function in Python Table API
my_table.join_lateral("split(a) as (word, length)")
Expand Down Expand Up @@ -231,18 +230,18 @@ Like Python scalar functions, you can use the above five ways to define Python T

{% highlight python %}
# option 1: generator function
@udtf(input_types=DataTypes.BIGINT(), result_types=DataTypes.BIGINT())
@udtf(result_types=DataTypes.BIGINT())
def generator_func(x):
yield 1
yield 2

# option 2: return iterator
@udtf(input_types=DataTypes.BIGINT(), result_types=DataTypes.BIGINT())
@udtf(result_types=DataTypes.BIGINT())
def iterator_func(x):
return range(5)

# option 3: return iterable
@udtf(input_types=DataTypes.BIGINT(), result_types=DataTypes.BIGINT())
@udtf(result_types=DataTypes.BIGINT())
def iterable_func(x):
result = [1, 2, 3]
return result
Expand Down
2 changes: 1 addition & 1 deletion docs/dev/table/python/vectorized_python_udfs.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ The following example shows how to define your own vectorized Python scalar func
and use it in a query:

{% highlight python %}
@udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], result_type=DataTypes.BIGINT(), udf_type="pandas")
@udf(result_type=DataTypes.BIGINT(), udf_type="pandas")
def add(i, j):
return i + j

Expand Down
2 changes: 1 addition & 1 deletion docs/dev/table/python/vectorized_python_udfs.zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ The following example shows how to define your own vectorized Python scalar func
and use it in a query:

{% highlight python %}
@udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], result_type=DataTypes.BIGINT(), udf_type="pandas")
@udf(result_type=DataTypes.BIGINT(), udf_type="pandas")
def add(i, j):
return i + j

Expand Down
21 changes: 9 additions & 12 deletions flink-python/pyflink/table/table_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,9 @@ def create_temporary_system_function(self, name: str,
::
>>> table_env.create_temporary_system_function(
... "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT()))
... "add_one", udf(lambda i: i + 1, result_type=DataTypes.BIGINT()))
>>> @udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()],
... result_type=DataTypes.BIGINT())
>>> @udf(result_type=DataTypes.BIGINT())
... def add(i, j):
... return i + j
>>> table_env.create_temporary_system_function("add", add)
Expand All @@ -233,7 +232,7 @@ def create_temporary_system_function(self, name: str,
... def eval(self, i):
... return i - 1
>>> table_env.create_temporary_system_function(
... "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT()))
... "subtract_one", udf(SubtractOne(), result_type=DataTypes.BIGINT()))
:param name: The name under which the function will be registered globally.
:param function: The function class containing the implementation. The function must have a
Expand Down Expand Up @@ -356,10 +355,9 @@ def create_temporary_function(self, path: str, function: UserDefinedFunctionWrap
::
>>> table_env.create_temporary_function(
... "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT()))
... "add_one", udf(lambda i: i + 1, result_type=DataTypes.BIGINT()))
>>> @udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()],
... result_type=DataTypes.BIGINT())
>>> @udf(result_type=DataTypes.BIGINT())
... def add(i, j):
... return i + j
>>> table_env.create_temporary_function("add", add)
Expand All @@ -368,7 +366,7 @@ def create_temporary_function(self, path: str, function: UserDefinedFunctionWrap
... def eval(self, i):
... return i - 1
>>> table_env.create_temporary_function(
... "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT()))
... "subtract_one", udf(SubtractOne(), result_type=DataTypes.BIGINT()))
:param path: The path under which the function will be registered.
See also the :class:`~pyflink.table.TableEnvironment` class description for
Expand Down Expand Up @@ -1101,10 +1099,9 @@ def register_function(self, name, function):
::
>>> table_env.register_function(
... "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT()))
... "add_one", udf(lambda i: i + 1, result_type=DataTypes.BIGINT()))
>>> @udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()],
... result_type=DataTypes.BIGINT())
>>> @udf(result_type=DataTypes.BIGINT())
... def add(i, j):
... return i + j
>>> table_env.register_function("add", add)
Expand All @@ -1113,7 +1110,7 @@ def register_function(self, name, function):
... def eval(self, i):
... return i - 1
>>> table_env.register_function(
... "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT()))
... "subtract_one", udf(SubtractOne(), result_type=DataTypes.BIGINT()))
:param name: The name under which the function is registered.
:type name: str
Expand Down
23 changes: 8 additions & 15 deletions flink-python/pyflink/table/tests/test_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def plus_two(i):
from test_dependency_manage_lib import add_two
return add_two(i)

self.t_env.register_function("add_two", udf(plus_two, DataTypes.BIGINT(),
DataTypes.BIGINT()))
self.t_env.register_function("add_two", udf(plus_two, result_type=DataTypes.BIGINT()))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()])
self.t_env.register_table_sink("Results", table_sink)
Expand Down Expand Up @@ -77,8 +76,7 @@ def plus_two(i):
from test_dependency_manage_lib import add_two
return add_two(i)

self.t_env.register_function("add_two", udf(plus_two, DataTypes.BIGINT(),
DataTypes.BIGINT()))
self.t_env.register_function("add_two", udf(plus_two, result_type=DataTypes.BIGINT()))

t = self.t_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b'])\
.select("add_two(a), a")
Expand Down Expand Up @@ -107,8 +105,7 @@ def check_requirements(i):
return i

self.t_env.register_function("check_requirements",
udf(check_requirements, DataTypes.BIGINT(),
DataTypes.BIGINT()))
udf(check_requirements, result_type=DataTypes.BIGINT()))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()])
self.t_env.register_table_sink("Results", table_sink)
Expand Down Expand Up @@ -153,9 +150,7 @@ def add_one(i):
from python_package1 import plus
return plus(i, 1)

self.t_env.register_function("add_one",
udf(add_one, DataTypes.BIGINT(),
DataTypes.BIGINT()))
self.t_env.register_function("add_one", udf(add_one, result_type=DataTypes.BIGINT()))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()])
self.t_env.register_table_sink("Results", table_sink)
Expand All @@ -181,8 +176,7 @@ def add_from_file(i):
return i + int(f.read())

self.t_env.register_function("add_from_file",
udf(add_from_file, DataTypes.BIGINT(),
DataTypes.BIGINT()))
udf(add_from_file, result_type=DataTypes.BIGINT()))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()])
self.t_env.register_table_sink("Results", table_sink)
Expand All @@ -207,8 +201,7 @@ def check_python_exec(i):
return i

self.t_env.register_function("check_python_exec",
udf(check_python_exec, DataTypes.BIGINT(),
DataTypes.BIGINT()))
udf(check_python_exec, result_type=DataTypes.BIGINT()))

def check_pyflink_gateway_disabled(i):
try:
Expand All @@ -222,8 +215,8 @@ def check_pyflink_gateway_disabled(i):
return i

self.t_env.register_function("check_pyflink_gateway_disabled",
udf(check_pyflink_gateway_disabled, DataTypes.BIGINT(),
DataTypes.BIGINT()))
udf(check_pyflink_gateway_disabled,
result_type=DataTypes.BIGINT()))

table_sink = source_sink_utils.TestAppendSink(
['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()])
Expand Down
Loading

0 comments on commit ef46460

Please sign in to comment.