Skip to content

Commit

Permalink
[FLINK-30168][python] Fix DataStream.execute_and_collect to support N…
Browse files Browse the repository at this point in the history
…one data and ObjectArray

This closes apache#21664.
  • Loading branch information
dianfu committed Jan 16, 2023
1 parent addca4e commit 4675773
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 6 deletions.
22 changes: 20 additions & 2 deletions flink-python/pyflink/datastream/tests/test_data_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,8 +1231,10 @@ def tearDown(self) -> None:
self.test_sink.clear()

def assert_equals_sorted(self, expected, actual):
expected.sort()
actual.sort()
# otherwise, it may thrown exceptions such as the following:
# TypeError: '<' not supported between instances of 'NoneType' and 'str'
expected.sort(key=lambda x: str(x))
actual.sort(key=lambda x: str(x))
self.assertEqual(expected, actual)

def test_data_stream_name(self):
Expand Down Expand Up @@ -1496,6 +1498,22 @@ def test_execute_and_collect(self):
actual = [r for r in results]
self.assert_equals_sorted(expected, actual)

test_data = [
(["test", "test"], [0.0, 0.0]),
([None, ], [0.0, 0.0])
]

ds = self.env.from_collection(
test_data,
type_info=Types.TUPLE(
[Types.OBJECT_ARRAY(Types.STRING()), Types.OBJECT_ARRAY(Types.DOUBLE())]
)
)
expected = test_data
with ds.execute_and_collect() as results:
actual = [result for result in results]
self.assert_equals_sorted(expected, actual)

def test_function_with_error(self):
ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1), ('e', 1)],
type_info=Types.ROW([Types.STRING(), Types.INT()]))
Expand Down
5 changes: 4 additions & 1 deletion flink-python/pyflink/datastream/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ def convert_to_python_obj(data, type_info):
pickle_bytes = gateway.jvm.PythonBridgeUtils. \
getPickledBytesFromJavaObject(data, type_info.get_java_type_info())
if isinstance(type_info, RowTypeInfo) or isinstance(type_info, TupleTypeInfo):
field_data = zip(list(pickle_bytes[1:]), type_info.get_field_types())
if isinstance(type_info, RowTypeInfo):
field_data = zip(list(pickle_bytes[1:]), type_info.get_field_types())
else:
field_data = zip(pickle_bytes, type_info.get_field_types())
fields = []
for data, field_type in field_data:
if len(data) == 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.typeutils.ListTypeInfo;
import org.apache.flink.api.java.typeutils.MapTypeInfo;
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
Expand Down Expand Up @@ -241,7 +242,7 @@ public static Object getPickledBytesFromJavaObject(Object obj, TypeInformation<?
Pickler pickler = new Pickler();
initialize();
if (obj == null) {
return new byte[0];
return pickler.dumps(null);
} else {
if (dataType instanceof SqlTimeTypeInfo) {
SqlTimeTypeInfo<?> sqlTimeTypeInfo =
Expand Down Expand Up @@ -270,15 +271,19 @@ public static Object getPickledBytesFromJavaObject(Object obj, TypeInformation<?
}
return fieldBytes;
} else if (dataType instanceof BasicArrayTypeInfo
|| dataType instanceof PrimitiveArrayTypeInfo) {
|| dataType instanceof PrimitiveArrayTypeInfo
|| dataType instanceof ObjectArrayTypeInfo) {
Object[] objects;
TypeInformation<?> elementType;
if (dataType instanceof BasicArrayTypeInfo) {
objects = (Object[]) obj;
elementType = ((BasicArrayTypeInfo<?, ?>) dataType).getComponentInfo();
} else {
} else if (dataType instanceof PrimitiveArrayTypeInfo) {
objects = primitiveArrayConverter(obj, dataType);
elementType = ((PrimitiveArrayTypeInfo<?>) dataType).getComponentType();
} else {
objects = (Object[]) obj;
elementType = ((ObjectArrayTypeInfo<?, ?>) dataType).getComponentInfo();
}
List<Object> serializedElements = new ArrayList<>(objects.length);
for (Object object : objects) {
Expand Down

0 comments on commit 4675773

Please sign in to comment.