Commit 2979a31
[FLINK-13594][python] Improve the 'from_element' method of flink python api to apply to blink planner.

This closes apache#9370
WeiZhong94 authored and hequn8128 committed Aug 8, 2019
1 parent 268da6a commit 2979a31
Showing 7 changed files with 214 additions and 156 deletions.
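
For context, this commit lifts the NotImplementedError that from_elements previously raised in batch mode under the blink planner. A minimal sketch of the usage it enables, condensed from the new test_blink_from_element test below (the two-column schema is illustrative):

    from pyflink.table import BatchTableEnvironment, DataTypes, EnvironmentSettings

    # blink planner, batch mode: the combination this commit adds support for
    t_env = BatchTableEnvironment.create(
        environment_settings=EnvironmentSettings.new_instance()
            .use_blink_planner()
            .in_batch_mode()
            .build())

    schema = DataTypes.ROW([DataTypes.FIELD("a", DataTypes.BIGINT()),
                            DataTypes.FIELD("b", DataTypes.STRING())])
    # from_elements now pickles the rows to a temp file and exposes them to the
    # planner through a PythonInputFormatTableSource (see table_environment.py below)
    t = t_env.from_elements([(1, "hi"), (2, "hello")], schema)
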
38 changes: 22 additions & 16 deletions flink-python/pyflink/table/table_environment.py
@@ -668,12 +668,21 @@ def _from_elements(self, elements, schema):
serializer.dump_to_stream(elements, temp_file)
finally:
temp_file.close()
return self._from_file(temp_file.name, schema)
row_type_info = _to_java_type(schema)
execution_config = self._get_execution_config(temp_file.name, schema)
gateway = get_gateway()
j_objs = gateway.jvm.PythonBridgeUtils.readPythonObjects(temp_file.name, True)
j_input_format = gateway.jvm.PythonTableUtils.getInputFormat(
j_objs, row_type_info, execution_config)
j_table_source = gateway.jvm.PythonInputFormatTableSource(
j_input_format, row_type_info)

return Table(self._j_tenv.fromTableSource(j_table_source))
finally:
os.unlink(temp_file.name)

@abstractmethod
def _from_file(self, filename, schema):
def _get_execution_config(self, filename, schema):
pass


@@ -683,12 +692,8 @@ def __init__(self, j_tenv):
self._j_tenv = j_tenv
super(StreamTableEnvironment, self).__init__(j_tenv)

def _from_file(self, filename, schema):
gateway = get_gateway()
jds = gateway.jvm.PythonBridgeUtils.createDataStreamFromFile(
self._j_tenv.execEnv(), filename, True)
return Table(gateway.jvm.PythonTableUtils.fromDataStream(
self._j_tenv, jds, _to_java_type(schema)))
def _get_execution_config(self, filename, schema):
return self._j_tenv.execEnv().getConfig()

def get_config(self):
"""
@@ -796,18 +801,19 @@ def __init__(self, j_tenv):
self._j_tenv = j_tenv
super(BatchTableEnvironment, self).__init__(j_tenv)

def _from_file(self, filename, schema):
def _get_execution_config(self, filename, schema):
gateway = get_gateway()
blink_t_env_class = get_java_class(
gateway.jvm.org.apache.flink.table.api.internal.TableEnvironmentImpl)
if blink_t_env_class == self._j_tenv.getClass():
raise NotImplementedError("The operation 'from_elements' in batch mode is currently "
"not supported when using blink planner.")
is_blink = (blink_t_env_class == self._j_tenv.getClass())
if is_blink:
# we cannot get an ExecutionConfig object from the TableEnvironmentImpl
# for the moment, just create a new ExecutionConfig.
execution_config = gateway.jvm.org.apache.flink.api.common.ExecutionConfig()
else:
jds = gateway.jvm.PythonBridgeUtils.createDataSetFromFile(
self._j_tenv.execEnv(), filename, True)
return Table(gateway.jvm.PythonTableUtils.fromDataSet(
self._j_tenv, jds, _to_java_type(schema)))
execution_config = self._j_tenv.execEnv().getConfig()

return execution_config

def get_config(self):
"""
52 changes: 49 additions & 3 deletions flink-python/pyflink/table/tests/test_calc.py
@@ -20,7 +20,7 @@
import datetime
from decimal import Decimal

from pyflink.table import DataTypes, Row
from pyflink.table import DataTypes, Row, BatchTableEnvironment, EnvironmentSettings
from pyflink.table.tests.test_types import ExamplePoint, PythonOnlyPoint, ExamplePointUDT, \
PythonOnlyUDT
from pyflink.testing import source_sink_utils
@@ -97,14 +97,60 @@ def test_from_element(self):
PythonOnlyPoint(3.0, 4.0))],
schema)
t.insert_into("Results")
self.t_env.execute("test")
t_env.execute("test")
actual = source_sink_utils.results()

expected = ['1,1.0,hi,hello,1970-01-02,01:00:00,1970-01-02 00:00:00.0,'
'1970-01-02 00:00:00.0,86400000010,[1.0, null],[1.0, 2.0],[abc],[1970-01-02],'
'1970-01-02 00:00:00.0,86400000,[1.0, null],[1.0, 2.0],[abc],[1970-01-02],'
'1,1,2.0,{key=1.0},[65, 66, 67, 68],[1.0, 2.0],[3.0, 4.0]']
self.assert_equals(actual, expected)

def test_blink_from_element(self):
t_env = BatchTableEnvironment.create(environment_settings=EnvironmentSettings
.new_instance().use_blink_planner()
.in_batch_mode().build())
field_names = ["a", "b", "c", "d", "e", "f", "g", "h",
"i", "j", "k", "l", "m", "n", "o", "p", "q", "r"]
field_types = [DataTypes.BIGINT(), DataTypes.DOUBLE(), DataTypes.STRING(),
DataTypes.STRING(), DataTypes.DATE(),
DataTypes.TIME(),
DataTypes.TIMESTAMP(),
DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(),
DataTypes.INTERVAL(DataTypes.DAY(), DataTypes.SECOND()),
DataTypes.ARRAY(DataTypes.DOUBLE()),
DataTypes.ARRAY(DataTypes.DOUBLE(False)),
DataTypes.ARRAY(DataTypes.STRING()),
DataTypes.ARRAY(DataTypes.DATE()),
DataTypes.DECIMAL(10, 0),
DataTypes.ROW([DataTypes.FIELD("a", DataTypes.BIGINT()),
DataTypes.FIELD("b", DataTypes.DOUBLE())]),
DataTypes.MAP(DataTypes.STRING(), DataTypes.DOUBLE()),
DataTypes.BYTES(),
PythonOnlyUDT()]
schema = DataTypes.ROW(
list(map(lambda field_name, field_type: DataTypes.FIELD(field_name, field_type),
field_names,
field_types)))
table_sink = source_sink_utils.TestAppendSink(field_names, field_types)
t_env.register_table_sink("Results", table_sink)
t = t_env.from_elements(
[(1, 1.0, "hi", "hello", datetime.date(1970, 1, 2), datetime.time(1, 0, 0),
datetime.datetime(1970, 1, 2, 0, 0), datetime.datetime(1970, 1, 2, 0, 0),
datetime.timedelta(days=1, microseconds=10),
[1.0, None], array.array("d", [1.0, 2.0]),
["abc"], [datetime.date(1970, 1, 2)], Decimal(1), Row("a", "b")(1, 2.0),
{"key": 1.0}, bytearray(b'ABCD'),
PythonOnlyPoint(3.0, 4.0))],
schema)
t.insert_into("Results")
t_env.execute("test")
actual = source_sink_utils.results()

expected = ['1,1.0,hi,hello,1970-01-02,01:00:00,1970-01-02 00:00:00.0,'
'1970-01-02 00:00:00.0,86400000,[1.0, null],[1.0, 2.0],[abc],[1970-01-02],'
'1.000000000000000000,1,2.0,{key=1.0},[65, 66, 67, 68],[3.0, 4.0]']
self.assert_equals(actual, expected)


if __name__ == '__main__':
import unittest
58 changes: 56 additions & 2 deletions flink-python/pyflink/table/tests/test_types.py
@@ -21,8 +21,11 @@
import datetime
import pickle
import sys
import tempfile
import unittest

from pyflink.serializers import BatchedSerializer, PickleSerializer

from pyflink.java_gateway import get_gateway
from pyflink.table.types import (_infer_schema_from_data, _infer_type,
_array_signed_int_typecode_ctype_mappings,
@@ -825,10 +828,26 @@ def test_atomic_type_with_data_type_with_parameters(self):
DataTypes.DECIMAL(20, 10, False)]
self.assertEqual(converted_python_types, expected)

# Legacy type tests
Types = gateway.jvm.org.apache.flink.table.api.Types
BlinkBigDecimalTypeInfo = \
gateway.jvm.org.apache.flink.table.runtime.typeutils.BigDecimalTypeInfo

java_types = [Types.STRING(),
Types.DECIMAL(),
BlinkBigDecimalTypeInfo(12, 5)]

converted_python_types = [_from_java_type(item) for item in java_types]

expected = [DataTypes.VARCHAR(2147483647),
DataTypes.DECIMAL(10, 0),
DataTypes.DECIMAL(12, 5)]
self.assertEqual(converted_python_types, expected)

def test_array_type(self):
# nullable/not_null flag will be lost during the conversion.
test_types = [DataTypes.ARRAY(DataTypes.BIGINT()),
# array type with not null basic data type means primitive array
DataTypes.ARRAY(DataTypes.BIGINT().not_null()),
DataTypes.ARRAY(DataTypes.BIGINT()),
DataTypes.ARRAY(DataTypes.STRING()),
DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.BIGINT())),
DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.STRING()))]
Expand Down Expand Up @@ -879,6 +898,41 @@ def test_row_type(self):
self.assertEqual(test_types, converted_python_types)


class DataSerializerTests(unittest.TestCase):

def test_java_pickle_deserializer(self):
temp_file = tempfile.NamedTemporaryFile(delete=False, dir=tempfile.mkdtemp())
serializer = PickleSerializer()
data = [(1, 2), (3, 4), (5, 6), (7, 8)]

try:
serializer.dump_to_stream(data, temp_file)
finally:
temp_file.close()

gateway = get_gateway()
result = [tuple(int_pair) for int_pair in
list(gateway.jvm.PythonBridgeUtils.readPythonObjects(temp_file.name, False))]

self.assertEqual(result, [(1, 2), (3, 4), (5, 6), (7, 8)])

def test_java_batch_deserializer(self):
temp_file = tempfile.NamedTemporaryFile(delete=False, dir=tempfile.mkdtemp())
serializer = BatchedSerializer(PickleSerializer(), 2)
data = [(1, 2), (3, 4), (5, 6), (7, 8)]

try:
serializer.dump_to_stream(data, temp_file)
finally:
temp_file.close()

gateway = get_gateway()
result = [tuple(int_pair) for int_pair in
list(gateway.jvm.PythonBridgeUtils.readPythonObjects(temp_file.name, True))]

self.assertEqual(result, [(1, 2), (3, 4), (5, 6), (7, 8)])


if __name__ == "__main__":
try:
import xmlrunner
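
The two DataSerializerTests above pin down the on-disk contract between pyflink's serializers and PythonBridgeUtils.readPythonObjects. A pure-Python sketch of that framing, assuming the PySpark-derived serializers that pyflink 1.9 ships (a 4-byte big-endian length prefix per pickled frame; each frame is one batch when BatchedSerializer is used):

    import pickle
    import struct
    from io import BytesIO

    from pyflink.serializers import BatchedSerializer, PickleSerializer

    buf = BytesIO()
    # batch size 2: frames should hold [(1, 2), (3, 4)] and then [(5, 6)]
    BatchedSerializer(PickleSerializer(), 2).dump_to_stream([(1, 2), (3, 4), (5, 6)], buf)

    buf.seek(0)
    frames = []
    while True:
        header = buf.read(4)
        if len(header) < 4:  # EOF, mirrors the EOFException handling on the Java side
            break
        (length,) = struct.unpack(">i", header)
        frames.append(pickle.loads(buf.read(length)))

    print(frames)  # expected: [[(1, 2), (3, 4)], [(5, 6)]]
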
19 changes: 8 additions & 11 deletions flink-python/pyflink/table/types.py
@@ -1665,17 +1665,7 @@ def _to_java_type(data_type):

# ArrayType
elif isinstance(data_type, ArrayType):
if type(data_type.element_type) in _primitive_array_element_types:
if data_type.element_type._nullable is False:
return Types.PRIMITIVE_ARRAY(_to_java_type(data_type.element_type))
else:
return Types.OBJECT_ARRAY(_to_java_type(data_type.element_type))
elif isinstance(data_type.element_type, VarCharType) or isinstance(
data_type.element_type, CharType):
return gateway.jvm.org.apache.flink.api.common.typeinfo.\
BasicArrayTypeInfo.STRING_ARRAY_TYPE_INFO
else:
return Types.OBJECT_ARRAY(_to_java_type(data_type.element_type))
return Types.OBJECT_ARRAY(_to_java_type(data_type.element_type))

# MapType
elif isinstance(data_type, MapType):
@@ -1783,8 +1773,15 @@ def _from_java_type(j_data_type):
type_info = logical_type.getTypeInformation()
BasicArrayTypeInfo = gateway.jvm.org.apache.flink.api.common.typeinfo.\
BasicArrayTypeInfo
BasicTypeInfo = gateway.jvm.org.apache.flink.api.common.typeinfo.BasicTypeInfo
if type_info == BasicArrayTypeInfo.STRING_ARRAY_TYPE_INFO:
data_type = DataTypes.ARRAY(DataTypes.STRING())
elif type_info == BasicTypeInfo.BIG_DEC_TYPE_INFO:
data_type = DataTypes.DECIMAL(10, 0)
elif type_info.getClass() == \
get_java_class(gateway.jvm.org.apache.flink.table.runtime.typeutils
.BigDecimalTypeInfo):
data_type = DataTypes.DECIMAL(type_info.precision(), type_info.scale())
else:
raise TypeError("Unsupported type: %s, it is recognized as a legacy type."
% type_info)
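
The practical effect of the new legacy-type branch in _from_java_type, restated from the type test above (a sketch that assumes a running pyflink session, i.e. a live gateway):

    from pyflink.java_gateway import get_gateway
    from pyflink.table.types import DataTypes, _from_java_type

    gateway = get_gateway()
    Types = gateway.jvm.org.apache.flink.table.api.Types
    BigDecimalTypeInfo = \
        gateway.jvm.org.apache.flink.table.runtime.typeutils.BigDecimalTypeInfo

    # legacy Types.DECIMAL() carries no precision/scale, so it maps to the default
    assert _from_java_type(Types.DECIMAL()) == DataTypes.DECIMAL(10, 0)
    # the blink planner's BigDecimalTypeInfo keeps its declared precision and scale
    assert _from_java_type(BigDecimalTypeInfo(12, 5)) == DataTypes.DECIMAL(12, 5)
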
flink-python/src/main/java/org/apache/flink/api/common/python/PythonBridgeUtils.java
@@ -17,16 +17,8 @@

package org.apache.flink.api.common.python;

import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.python.pickle.ArrayConstructor;
import org.apache.flink.api.common.python.pickle.ByteArrayConstructor;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.util.Collector;

import net.razorvine.pickle.Unpickler;

@@ -39,38 +31,46 @@
import java.util.List;

/**
* Utility class that contains helper methods to create a DataStream/DataSet from
* Utility class that contains helper methods to create a TableSource from
* a file which contains Python objects.
*/
public final class PythonBridgeUtils {

/**
* Creates a DataStream from a file which contains serialized python objects.
*/
public static DataStream<Object[]> createDataStreamFromFile(
final StreamExecutionEnvironment streamExecutionEnvironment,
final String fileName,
final boolean batched) throws IOException {
return streamExecutionEnvironment
.fromCollection(readPythonObjects(fileName))
.flatMap(new PythonFlatMapFunction(batched))
.returns(Types.GENERIC(Object[].class));
private static Object[] getObjectArrayFromUnpickledData(Object input) {
if (input.getClass().isArray()) {
return (Object[]) input;
} else {
return ((ArrayList<Object>) input).toArray(new Object[0]);
}
}

/**
* Creates a DataSet from a file which contains serialized python objects.
*/
public static DataSet<Object[]> createDataSetFromFile(
final ExecutionEnvironment executionEnvironment,
final String fileName,
final boolean batched) throws IOException {
return executionEnvironment
.fromCollection(readPythonObjects(fileName))
.flatMap(new PythonFlatMapFunction(batched))
.returns(Types.GENERIC(Object[].class));
public static List<Object[]> readPythonObjects(String fileName, boolean batched)
throws IOException {
List<byte[]> data = readPickledBytes(fileName);
Unpickler unpickle = new Unpickler();
initialize();
List<Object[]> unpickledData = new ArrayList<>();
for (byte[] pickledData: data) {
Object obj = unpickle.loads(pickledData);
if (batched) {
if (obj instanceof Object[]) {
Object[] arrayObj = (Object[]) obj;
for (Object o : arrayObj) {
unpickledData.add(getObjectArrayFromUnpickledData(o));
}
} else {
for (Object o : (ArrayList<Object>) obj) {
unpickledData.add(getObjectArrayFromUnpickledData(o));
}
}
} else {
unpickledData.add(getObjectArrayFromUnpickledData(obj));
}
}
return unpickledData;
}

private static List<byte[]> readPythonObjects(final String fileName) throws IOException {
private static List<byte[]> readPickledBytes(final String fileName) throws IOException {
List<byte[]> objs = new LinkedList<>();
try (DataInputStream din = new DataInputStream(new FileInputStream(fileName))) {
try {
@@ -87,50 +87,6 @@ private static List<byte[]> readPythonObjects(final String fileName) throws IOException {
return objs;
}

private static final class PythonFlatMapFunction extends RichFlatMapFunction<byte[], Object[]> {

private static final long serialVersionUID = 1L;

private final boolean batched;
private transient Unpickler unpickle;

PythonFlatMapFunction(boolean batched) {
this.batched = batched;
initialize();
}

@Override
public void open(Configuration parameters) {
this.unpickle = new Unpickler();
}

@Override
public void flatMap(byte[] value, Collector<Object[]> out) throws Exception {
Object obj = unpickle.loads(value);
if (batched) {
if (obj instanceof Object[]) {
for (int i = 0; i < ((Object[]) obj).length; i++) {
collect(out, ((Object[]) obj)[i]);
}
} else {
for (Object o : (ArrayList<Object>) obj) {
collect(out, o);
}
}
} else {
collect(out, obj);
}
}

private void collect(Collector<Object[]> out, Object obj) {
if (obj.getClass().isArray()) {
out.collect((Object[]) obj);
} else {
out.collect(((ArrayList<Object>) obj).toArray(new Object[0]));
}
}
}

private static boolean initialized = false;
private static void initialize() {
synchronized (PythonBridgeUtils.class) {
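
Design note: the deleted PythonFlatMapFunction unpickled records lazily inside the job at runtime, whereas the new readPythonObjects unpickles eagerly on the client and hands the materialized List<Object[]> to PythonTableUtils.getInputFormat (see the Python side above). The batched flag matches BatchedSerializer on the Python side: when true, every pickled frame is itself a batch of rows. A pure-Python mirror of the flattening logic, for illustration only (the real code runs on the JVM):

    import pickle

    def read_python_objects(frames, batched):
        # frames: list of pickled payloads, as produced by the framing sketch above
        def to_row(obj):
            # mirrors getObjectArrayFromUnpickledData: a tuple (Object[]) or a
            # list (ArrayList) both become one flat row of objects
            return list(obj)

        rows = []
        for payload in frames:
            obj = pickle.loads(payload)
            if batched:
                for o in obj:               # each frame holds a batch of rows
                    rows.append(to_row(o))
            else:
                rows.append(to_row(obj))    # each frame holds a single row
        return rows
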