[FLINK-19233][python] Support distinct and filter keywords on Python UDAF

This closes apache#13804.
WeiZhong94 authored and dianfu committed Oct 28, 2020
1 parent 0a14ad1 commit 6ddd2c9
Showing 14 changed files with 371 additions and 117 deletions.
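
For context, the feature added here lets a Python UDAF call carry the SQL distinct and filter modifiers. A minimal usage sketch, adapted from the test_distinct_and_filter test added in this commit (it assumes a Python UDAF registered as "concat" and a source view "retract_table", both set up in that test):

    # Each call shares or owns a distinct view depending on its arguments;
    # the filter clause restricts which rows reach the aggregation.
    result = t_env.sql_query(
        "select concat(distinct b, '.') as a, "
        "concat(distinct b, ',') filter (where c = 'hi') as b, "
        "concat(distinct b, ',') filter (where c = 'hello') as c, "
        "c as d "
        "from retract_table group by c")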
75 changes: 73 additions & 2 deletions flink-python/pyflink/fn_execution/aggregate.py
@@ -16,7 +16,7 @@
# limitations under the License.
################################################################################
from abc import ABC, abstractmethod
-from typing import List
+from typing import List, Dict

from apache_beam.coders import PickleCoder, Coder

@@ -143,6 +143,19 @@ def __init__(self, state_id, field_index, key_coder, value_coder):
self.value_coder = value_coder


class DistinctViewDescriptor(object):

def __init__(self, input_extractor, filter_args):
self._input_extractor = input_extractor
self._filter_args = filter_args

def get_input_extractor(self):
return self._input_extractor

def get_filter_args(self):
return self._filter_args
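
A DistinctViewDescriptor bundles the input extractor and the filter-argument indexes for one shared distinct state view. As generate_func in operations.py below shows, all aggregate calls with the same distinct arguments share a single descriptor, and an empty filter_args list means every row updates the shared view. A hypothetical construction for two calls on the same column, one of them unfiltered:

    # Hypothetical: concat(distinct b) and concat(distinct b) filter (where ...)
    # share one view on column b; because one of the calls is unfiltered,
    # filter_args collapses to [] and every row is counted in the shared view.
    descriptor = DistinctViewDescriptor(
        input_extractor=lambda value: (value[1],),  # extracts column b
        filter_args=[])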


class RowKeySelector(object):
"""
A simple key selector used to extract the current key from the input Row according to the
@@ -287,7 +300,10 @@ def __init__(self,
udfs: List[AggregateFunction],
input_extractors: List,
index_of_count_star: int,
-                 udf_data_view_specs: List[List[DataViewSpec]]):
+                 udf_data_view_specs: List[List[DataViewSpec]],
+                 filter_args: List[int],
+                 distinct_indexes: List[int],
+                 distinct_view_descriptors: Dict[int, DistinctViewDescriptor]):
self._udfs = udfs
self._input_extractors = input_extractors
self._accumulators = None # type: Row
@@ -297,6 +313,10 @@
self._get_value_indexes.remove(index_of_count_star)
self._udf_data_view_specs = udf_data_view_specs
self._udf_data_views = []
self._filter_args = filter_args
self._distinct_indexes = distinct_indexes
self._distinct_view_descriptors = distinct_view_descriptors
self._distinct_data_views = {}

def open(self, state_data_view_store):
for udf in self._udfs:
@@ -317,17 +337,68 @@ def open(self, state_data_view_store):
PickleCoder(),
PickleCoder())
self._udf_data_views.append(data_views)
for key in self._distinct_view_descriptors.keys():
self._distinct_data_views[key] = state_data_view_store.get_state_map_view(
"agg%ddistinct" % key,
PickleCoder(),
PickleCoder())

def accumulate(self, input_data: Row):
for i in range(len(self._udfs)):
if i in self._distinct_data_views:
if len(self._distinct_view_descriptors[i].get_filter_args()) == 0:
filtered = False
else:
filtered = True
for filter_arg in self._distinct_view_descriptors[i].get_filter_args():
if input_data[filter_arg]:
filtered = False
break
if not filtered:
input_extractor = self._distinct_view_descriptors[i].get_input_extractor()
args = input_extractor(input_data)
if args in self._distinct_data_views[i]:
self._distinct_data_views[i][args] += 1
else:
self._distinct_data_views[i][args] = 1
if self._filter_args[i] >= 0 and not input_data[self._filter_args[i]]:
continue
input_extractor = self._input_extractors[i]
args = input_extractor(input_data)
if self._distinct_indexes[i] >= 0:
if args in self._distinct_data_views[self._distinct_indexes[i]]:
if self._distinct_data_views[self._distinct_indexes[i]][args] > 1:
continue
else:
raise Exception(
"The args are not in the distinct data view, this should not happen.")
self._udfs[i].accumulate(self._accumulators[i], *args)

def retract(self, input_data: Row):
for i in range(len(self._udfs)):
if i in self._distinct_data_views:
if len(self._distinct_view_descriptors[i].get_filter_args()) == 0:
filtered = False
else:
filtered = True
for filter_arg in self._distinct_view_descriptors[i].get_filter_args():
if input_data[filter_arg]:
filtered = False
break
if not filtered:
input_extractor = self._distinct_view_descriptors[i].get_input_extractor()
args = input_extractor(input_data)
if args in self._distinct_data_views[i]:
self._distinct_data_views[i][args] -= 1
if self._distinct_data_views[i][args] == 0:
del self._distinct_data_views[i][args]
if self._filter_args[i] >= 0 and not input_data[self._filter_args[i]]:
continue
input_extractor = self._input_extractors[i]
args = input_extractor(input_data)
if self._distinct_indexes[i] >= 0 and \
args in self._distinct_data_views[self._distinct_indexes[i]]:
continue
self._udfs[i].retract(self._accumulators[i], *args)
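
Taken together, accumulate and retract keep a reference count per distinct key and only invoke the underlying UDAF on the 0-to-1 and 1-to-0 transitions. A plain-dict sketch of that bookkeeping (the real code stores the counts in the keyed map state created in open above):

    # Illustrative reference counting behind the distinct views above.
    counts = {}

    def on_accumulate(args):
        counts[args] = counts.get(args, 0) + 1
        return counts[args] == 1  # call UDAF.accumulate only on first occurrence

    def on_retract(args):
        counts[args] -= 1
        if counts[args] == 0:
            del counts[args]
            return True  # call UDAF.retract only when the last occurrence is gone
        return False

    assert on_accumulate(("Hi",)) is True   # first "Hi": accumulate
    assert on_accumulate(("Hi",)) is False  # duplicate: skipped
    assert on_retract(("Hi",)) is False     # one copy left: no retract yet
    assert on_retract(("Hi",)) is True      # last copy gone: retract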

def merge(self, accumulators: Row):
106 changes: 60 additions & 46 deletions flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py

Large diffs are not rendered by default.

28 changes: 25 additions & 3 deletions flink-python/pyflink/fn_execution/operation_utils.py
@@ -180,7 +180,10 @@ def _next_constant_num():
return constant_value_name, parsed_constant_value


-def extract_user_defined_aggregate_function(user_defined_function_proto):
+def extract_user_defined_aggregate_function(
+        current_index,
+        user_defined_function_proto,
+        distinct_info_dict: Dict[Tuple[List[str]], Tuple[List[int], List[int]]]):
user_defined_agg = pickle.loads(user_defined_function_proto.payload)
assert isinstance(user_defined_agg, AggregateFunction)
args_str = []
@@ -193,7 +196,26 @@ def extract_user_defined_aggregate_function(user_defined_function_proto):
# the input argument is a constant value
constant_value_name, parsed_constant_value = \
_parse_constant_value(arg.inputConstant)
+            for key, value in local_variable_dict.items():
+                if value == parsed_constant_value:
+                    constant_value_name = key
+                    break
+            if constant_value_name not in local_variable_dict:
+                local_variable_dict[constant_value_name] = parsed_constant_value
             args_str.append(constant_value_name)
-            local_variable_dict[constant_value_name] = parsed_constant_value

-    return user_defined_agg, eval("lambda value : [%s]" % ",".join(args_str), local_variable_dict)
+    if user_defined_function_proto.distinct:
+        if tuple(args_str) in distinct_info_dict:
+            distinct_info_dict[tuple(args_str)][0].append(current_index)
+            distinct_info_dict[tuple(args_str)][1].append(user_defined_function_proto.filter_arg)
+            distinct_index = distinct_info_dict[tuple(args_str)][0][0]
+        else:
+            distinct_info_dict[tuple(args_str)] = \
+                ([current_index], [user_defined_function_proto.filter_arg])
+            distinct_index = current_index
+    else:
+        distinct_index = -1
+    return user_defined_agg, \
+        eval("lambda value : (%s,)" % ",".join(args_str), local_variable_dict), \
+        user_defined_function_proto.filter_arg, \
+        distinct_index
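
The distinct_info_dict threaded through this function groups aggregate calls by the stringified argument expressions they aggregate distinctly; the first call with a given key owns the shared view. A sketch of its shape, assuming two calls distinct on the same column but with different filters:

    # Hypothetical contents: the functions at indexes 1 and 2 both aggregate
    # distinct value[1]; their filter args are columns 3 and 4; index 1, the
    # first of the pair, becomes the owner of the shared distinct view.
    distinct_info_dict = {
        ("value[1]",): ([1, 2], [3, 4]),
    }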
31 changes: 27 additions & 4 deletions flink-python/pyflink/fn_execution/operations.py
@@ -26,7 +26,7 @@
from pyflink.fn_execution.beam.beam_coders import DataViewFilterCoder
from pyflink.fn_execution.operation_utils import extract_user_defined_aggregate_function
from pyflink.fn_execution.aggregate import RowKeySelector, SimpleAggsHandleFunction, \
-    GroupAggFunction, extract_data_view_specs
+    GroupAggFunction, extract_data_view_specs, DistinctViewDescriptor
from pyflink.metrics.metricbase import GenericMetricGroup
from pyflink.table import FunctionContext, Row
from pyflink.table.functions import Count1AggFunction
@@ -285,23 +285,46 @@ def open(self):
def generate_func(self, serialized_fn):
user_defined_aggs = []
input_extractors = []
filter_args = []
# stores the indexes of the distinct views which the agg functions used
distinct_indexes = []
# stores the indexes of the functions which share the same distinct view
# and the filter args of them
distinct_info_dict = {}
for i in range(len(serialized_fn.udfs)):
if i != self.index_of_count_star:
-                user_defined_agg, input_extractor = extract_user_defined_aggregate_function(
-                    serialized_fn.udfs[i])
+                user_defined_agg, input_extractor, filter_arg, distinct_index = \
+                    extract_user_defined_aggregate_function(
+                        i, serialized_fn.udfs[i], distinct_info_dict)
else:
user_defined_agg = Count1AggFunction()
filter_arg = -1
distinct_index = -1

def dummy_input_extractor(value):
return []
input_extractor = dummy_input_extractor
user_defined_aggs.append(user_defined_agg)
input_extractors.append(input_extractor)
filter_args.append(filter_arg)
distinct_indexes.append(distinct_index)
distinct_view_descriptors = {}
for agg_index_list, filter_arg_list in distinct_info_dict.values():
if -1 in filter_arg_list:
# If there is a non-filter call, we don't need to check filter or not before
# writing the distinct data view.
filter_arg_list = []
# use the agg index of the first function as the key of shared distinct view
distinct_view_descriptors[agg_index_list[0]] = DistinctViewDescriptor(
input_extractors[agg_index_list[0]], filter_arg_list)
aggs_handler_function = SimpleAggsHandleFunction(
user_defined_aggs,
input_extractors,
self.index_of_count_star,
-            self.data_view_specs)
+            self.data_view_specs,
+            filter_args,
+            distinct_indexes,
+            distinct_view_descriptors)
key_selector = RowKeySelector(self.grouping)
if len(self.data_view_specs) > 0:
state_value_coder = DataViewFilterCoder(self.data_view_specs)
8 changes: 4 additions & 4 deletions flink-python/pyflink/fn_execution/state_impl.py
@@ -557,21 +557,21 @@ def __init__(self,

def get(self, map_key):
if self._is_empty:
raise KeyError("Map key %s not found!" % map_key)
raise KeyError("Map key %s not found!" % str(map_key))
if map_key in self._write_cache:
exists, value = self._write_cache[map_key]
if exists:
return value
else:
raise KeyError("Map key %s not found!" % map_key)
raise KeyError("Map key %s not found!" % str(map_key))
if self._cleared:
raise KeyError("Map key %s not found!" % map_key)
raise KeyError("Map key %s not found!" % str(map_key))
exists, value = self._map_state_handler.blocking_get(
self._state_key, map_key, self._map_key_coder_impl, self._map_value_coder_impl)
if exists:
return value
else:
raise KeyError("Map key %s not found!" % map_key)
raise KeyError("Map key %s not found!" % str(map_key))
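
The str() wrapping is needed because the distinct views introduced in this commit key their map state by argument tuples, and %-formatting unpacks a bare tuple into multiple format arguments:

    key = ("Hi", 1)
    # "Map key %s not found!" % key     would raise TypeError: not all
    # arguments converted during string formatting
    "Map key %s not found!" % str(key)  # "Map key ('Hi', 1) not found!"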

def put(self, map_key, map_value):
self._write_cache[map_key] = (True, map_value)
4 changes: 4 additions & 0 deletions flink-python/pyflink/proto/flink-fn-execution.proto
@@ -131,6 +131,10 @@ message UserDefinedAggregateFunction {
repeated Input inputs = 2;

repeated DataViewSpec specs = 3;

int32 filter_arg = 4;

bool distinct = 5;
}
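
Both sides treat -1 as the "absent" sentinel for filter_arg: it holds the index of the boolean filter column for the call, or -1 when the call has no filter clause, matching the `self._filter_args[i] >= 0` and `-1 in filter_arg_list` checks in the Python changes above.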

// A list of the user-defined aggregate functions to be executed in a group aggregate operation.
65 changes: 63 additions & 2 deletions flink-python/pyflink/table/tests/test_aggregate.py
@@ -82,6 +82,32 @@ def get_result_type(self):

class ConcatAggregateFunction(AggregateFunction):

def get_value(self, accumulator):
str_list = [i for i in accumulator[0]]
str_list.sort()
return accumulator[1].join(str_list)

def create_accumulator(self):
return Row([], '')

def accumulate(self, accumulator, *args):
accumulator[1] = args[1]
accumulator[0].append(args[0])

def retract(self, accumulator, *args):
accumulator[0].remove(args[0])

def get_accumulator_type(self):
return DataTypes.ROW([
DataTypes.FIELD("f0", DataTypes.ARRAY(DataTypes.STRING())),
DataTypes.FIELD("f1", DataTypes.BIGINT())])

def get_result_type(self):
return DataTypes.STRING()


class ListViewConcatAggregateFunction(AggregateFunction):

def get_value(self, accumulator):
return accumulator[1].join(accumulator[0])

@@ -96,7 +122,9 @@ def retract(self, accumulator, *args):
raise NotImplementedError

def get_accumulator_type(self):
return DataTypes.ROW([DataTypes.FIELD("f0", DataTypes.LIST_VIEW(DataTypes.STRING()))])
return DataTypes.ROW([
DataTypes.FIELD("f0", DataTypes.LIST_VIEW(DataTypes.STRING())),
DataTypes.FIELD("f1", DataTypes.BIGINT())])

def get_result_type(self):
return DataTypes.STRING()
@@ -241,7 +269,7 @@ def test_using_decorator(self):
self.assertEqual(result_type, DataTypes.INT())

def test_list_view(self):
-        my_concat = udaf(ConcatAggregateFunction())
+        my_concat = udaf(ListViewConcatAggregateFunction())
self.t_env.get_config().get_configuration().set_string(
"python.fn-execution.bundle.size", "2")
# trigger the cache eviction in a bundle.
@@ -358,6 +386,39 @@ def test_map_view_iterate(self):
["Hi,Hi2,Hi3", "1,2,3", "Hi:3,Hi2:2,Hi3:1", 3, "hi"]],
columns=['a', 'b', 'c', 'd', 'e']))

def test_distinct_and_filter(self):
self.t_env.create_temporary_system_function(
"concat",
ConcatAggregateFunction())
t = self.t_env.from_elements(
[(1, 'Hi_', 'hi'),
(1, 'Hi', 'hi'),
(2, 'hello', 'hello'),
(3, 'Hi_', 'hi'),
(3, 'Hi', 'hi'),
(4, 'hello', 'hello'),
(5, 'Hi2_', 'hi'),
(5, 'Hi2', 'hi'),
(6, 'hello2', 'hello'),
(7, 'Hi', 'hi'),
(8, 'hello', 'hello'),
(9, 'Hi2', 'hi'),
(13, 'Hi3', 'hi')], ['a', 'b', 'c'])
self.t_env.create_temporary_view("source", t)
table_with_retract_message = self.t_env.sql_query(
"select LAST_VALUE(b) as b, LAST_VALUE(c) as c from source group by a")
self.t_env.create_temporary_view("retract_table", table_with_retract_message)
result = self.t_env.sql_query(
"select concat(distinct b, '.') as a, "
"concat(distinct b, ',') filter (where c = 'hi') as b, "
"concat(distinct b, ',') filter (where c = 'hello') as c, "
"c as d "
"from retract_table group by c")
assert_frame_equal(result.to_pandas().sort_values(by='a').reset_index(drop=True),
pd.DataFrame([["Hi.Hi2.Hi3", "Hi,Hi2,Hi3", "", "hi"],
["hello.hello2", "", "hello,hello2", "hello"]],
columns=['a', 'b', 'c', 'd']))


if __name__ == '__main__':
import unittest
@@ -21,6 +21,7 @@
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.fnexecution.v1.FlinkFnApi;
import org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo;
import org.apache.flink.table.functions.python.PythonAggregateFunctionInfo;
import org.apache.flink.table.functions.python.PythonFunctionInfo;
import org.apache.flink.table.planner.typeutils.DataViewUtils;

@@ -56,10 +57,12 @@ public static FlinkFnApi.UserDefinedFunction getUserDefinedFunctionProto(PythonF
}

public static FlinkFnApi.UserDefinedAggregateFunction getUserDefinedAggregateFunctionProto(
-            PythonFunctionInfo pythonFunctionInfo,
+            PythonAggregateFunctionInfo pythonFunctionInfo,
DataViewUtils.DataViewSpec[] dataViewSpecs) {
FlinkFnApi.UserDefinedAggregateFunction.Builder builder = FlinkFnApi.UserDefinedAggregateFunction.newBuilder();
builder.setPayload(ByteString.copyFrom(pythonFunctionInfo.getPythonFunction().getSerializedPythonFunction()));
builder.setDistinct(pythonFunctionInfo.isDistinct());
builder.setFilterArg(pythonFunctionInfo.getFilterArg());
for (Object input : pythonFunctionInfo.getInputs()) {
FlinkFnApi.Input.Builder inputProto =
FlinkFnApi.Input.newBuilder();