Skip to content

Commit

Permalink
[SPARK-42725][CONNECT][PYTHON] Make LiteralExpression support array params
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?
Make LiteralExpression support array

### Why are the changes needed?
MLlib requires literals to carry array params, like `IntArrayParam` and `DoubleArrayArrayParam`.

Note that this PR doesn't affect the existing `functions.lit` method, which applies the unresolved `CreateArray` expression to support array input.

### Does this PR introduce _any_ user-facing change?
No, dev-only

### How was this patch tested?
Added a unit test.

Closes apache#40349 from zhengruifeng/connect_py_ml_lit.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Mar 10, 2023
1 parent b36966f commit d6d0fc7
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ object LiteralProtoConverter {
// Builds an Array literal proto from a Scala array: records the element
// type (derived from the array's runtime component type) and recursively
// converts each element via toLiteralProto.
def arrayBuilder(array: Array[_]) = {
  val ab = builder.getArrayBuilder
    .setElementType(toConnectProtoType(toDataType(array.getClass.getComponentType)))
  // The repeated field is `elements` (renamed from `element` in SPARK-42725);
  // the stale pre-rename addElement call is removed to avoid double-appending.
  array.foreach(x => ab.addElements(toLiteralProto(x)))
  ab
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ message Expression {
}

// Literal array: an element data type plus the element values.
// Field names follow proto snake_case convention (SPARK-42725 rename);
// keeping only one declaration per field number (1, 2) — the previous
// camelCase duplicates would make the message definition invalid.
message Array {
  DataType element_type = 1;
  repeated Literal elements = 2;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ object LiteralValueProtoConverter {
expressions.Literal.create(
toArrayData(lit.getArray),
ArrayType(DataTypeProtoConverter.toCatalystType(lit.getArray.getElementType)))

case _ =>
throw InvalidPlanInput(
s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" +
Expand Down Expand Up @@ -143,7 +144,7 @@ object LiteralValueProtoConverter {
def makeArrayData[T](converter: proto.Expression.Literal => T)(implicit
tag: ClassTag[T]): Array[T] = {
val builder = mutable.ArrayBuilder.make[T]
val elementList = array.getElementList
val elementList = array.getElementsList
builder.sizeHint(elementList.size())
val iter = elementList.iterator()
while (iter.hasNext) {
Expand Down
38 changes: 30 additions & 8 deletions python/pyspark/sql/connect/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from pyspark.sql.types import (
_from_numpy_type,
DateType,
ArrayType,
NullType,
BooleanType,
BinaryType,
Expand Down Expand Up @@ -193,6 +194,7 @@ def __init__(self, value: Any, dataType: DataType) -> None:
TimestampType,
TimestampNTZType,
DayTimeIntervalType,
ArrayType,
),
)

Expand Down Expand Up @@ -247,6 +249,8 @@ def __init__(self, value: Any, dataType: DataType) -> None:
assert isinstance(value, datetime.timedelta)
value = DayTimeIntervalType().toInternal(value)
assert value is not None
elif isinstance(dataType, ArrayType):
assert isinstance(value, list)
else:
raise TypeError(f"Unsupported Data Type {dataType}")

Expand Down Expand Up @@ -280,14 +284,25 @@ def _infer_type(cls, value: Any) -> DataType:
return DateType()
elif isinstance(value, datetime.timedelta):
return DayTimeIntervalType()
else:
if isinstance(value, np.generic):
dt = _from_numpy_type(value.dtype)
if dt is not None:
return dt
elif isinstance(value, np.bool_):
return BooleanType()
raise TypeError(f"Unsupported Data Type {type(value).__name__}")
elif isinstance(value, np.generic):
dt = _from_numpy_type(value.dtype)
if dt is not None:
return dt
elif isinstance(value, np.bool_):
return BooleanType()
elif isinstance(value, list):
# follow the 'infer_array_from_first_element' strategy in 'sql.types._infer_type'
# right now, it's dedicated for pyspark.ml params like array<...>, array<array<...>>
if len(value) == 0:
raise TypeError("Can not infer Array Type from an empty list")
first = value[0]
if first is None:
raise TypeError(
"Can not infer Array Type from an list with None as the first element"
)
return ArrayType(LiteralExpression._infer_type(first), True)

raise TypeError(f"Unsupported Data Type {type(value).__name__}")

@classmethod
def _from_value(cls, value: Any) -> "LiteralExpression":
Expand Down Expand Up @@ -330,6 +345,13 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
expr.literal.timestamp_ntz = int(self._value)
elif isinstance(self._dataType, DayTimeIntervalType):
expr.literal.day_time_interval = int(self._value)
elif isinstance(self._dataType, ArrayType):
element_type = self._dataType.elementType
expr.literal.array.element_type.CopyFrom(pyspark_types_to_proto_types(element_type))
for v in self._value:
expr.literal.array.elements.append(
LiteralExpression(v, element_type).to_plan(session).literal
)
else:
raise ValueError(f"Unsupported Data Type {self._dataType}")

Expand Down
66 changes: 33 additions & 33 deletions python/pyspark/sql/connect/proto/expressions_pb2.py

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions python/pyspark/sql/connect/proto/expressions_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -443,29 +443,29 @@ class Expression(google.protobuf.message.Message):
class Array(google.protobuf.message.Message):
    # Generated stub for Expression.Literal.Array (post SPARK-42725 rename:
    # elementType/element -> element_type/elements). Duplicate pre-rename
    # declarations removed so each field is declared exactly once.
    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    ELEMENT_TYPE_FIELD_NUMBER: builtins.int
    ELEMENTS_FIELD_NUMBER: builtins.int
    @property
    def element_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
    @property
    def elements(
        self,
    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
        global___Expression.Literal
    ]: ...
    def __init__(
        self,
        *,
        element_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
        elements: collections.abc.Iterable[global___Expression.Literal] | None = ...,
    ) -> None: ...
    def HasField(
        self, field_name: typing_extensions.Literal["element_type", b"element_type"]
    ) -> builtins.bool: ...
    def ClearField(
        self,
        field_name: typing_extensions.Literal[
            "element_type", b"element_type", "elements", b"elements"
        ],
    ) -> None: ...
Expand Down
3 changes: 0 additions & 3 deletions python/pyspark/sql/tests/connect/test_connect_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
Row,
StructField,
StructType,
ArrayType,
MapType,
NullType,
DateType,
Expand Down Expand Up @@ -437,7 +436,6 @@ def test_literal_with_unsupported_type(self):
(0.1, DecimalType()),
(datetime.date(2022, 12, 13), TimestampType()),
(datetime.timedelta(1, 2, 3), DateType()),
([1, 2, 3], ArrayType(IntegerType())),
({1: 2}, MapType(IntegerType(), IntegerType())),
(
{"a": "xyz", "b": 1},
Expand Down Expand Up @@ -474,7 +472,6 @@ def test_literal_null(self):
for value, dataType in [
("123", NullType()),
(123, NullType()),
(None, ArrayType(IntegerType())),
(None, MapType(IntegerType(), IntegerType())),
(None, StructType([StructField("a", StringType())])),
]:
Expand Down
42 changes: 42 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import uuid
import datetime
import decimal
import math

from pyspark.testing.connectutils import (
PlanOnlyTestFixture,
Expand All @@ -31,6 +32,7 @@
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import WriteOperation, Read
from pyspark.sql.connect.readwriter import DataFrameReader
from pyspark.sql.connect.expressions import LiteralExpression
from pyspark.sql.connect.functions import col, lit, max, min, sum
from pyspark.sql.connect.types import pyspark_types_to_proto_types
from pyspark.sql.types import (
Expand Down Expand Up @@ -944,6 +946,46 @@ def test_column_expressions(self):
mod_fun.unresolved_function.arguments[0].unresolved_attribute.unparsed_identifier, "id"
)

def test_literal_expression_with_arrays(self):
    """Round-trip plain Python lists through LiteralExpression array literals."""

    def as_literal(value):
        # Infer the element type from the value and serialize to a literal proto.
        return LiteralExpression._from_value(value).to_plan(None).literal

    strings = as_literal(["x", "y", "z"])
    self.assertTrue(strings.array.element_type.HasField("string"))
    self.assertEqual(len(strings.array.elements), 3)
    self.assertEqual(strings.array.elements[0].string, "x")
    self.assertEqual(strings.array.elements[1].string, "y")
    self.assertEqual(strings.array.elements[2].string, "z")

    ints = as_literal([3, -3])
    self.assertTrue(ints.array.element_type.HasField("integer"))
    self.assertEqual(len(ints.array.elements), 2)
    self.assertEqual(ints.array.elements[0].integer, 3)
    self.assertEqual(ints.array.elements[1].integer, -3)

    doubles = as_literal([float("nan"), -3.0, 0.0])
    self.assertTrue(doubles.array.element_type.HasField("double"))
    self.assertEqual(len(doubles.array.elements), 3)
    # NaN cannot be compared with assertEqual; check it explicitly.
    self.assertTrue(math.isnan(doubles.array.elements[0].double))
    self.assertEqual(doubles.array.elements[1].double, -3.0)
    self.assertEqual(doubles.array.elements[2].double, 0.0)

    nested_ints = as_literal([[3, 4], [5, 6, 7]])
    self.assertTrue(nested_ints.array.element_type.HasField("array"))
    self.assertTrue(nested_ints.array.element_type.array.element_type.HasField("integer"))
    self.assertEqual(len(nested_ints.array.elements), 2)
    self.assertEqual(
        [len(e.array.elements) for e in nested_ints.array.elements], [2, 3]
    )

    nested_doubles = as_literal([[float("inf"), 0.4], [0.5, float("nan")], []])
    self.assertTrue(nested_doubles.array.element_type.HasField("array"))
    self.assertTrue(nested_doubles.array.element_type.array.element_type.HasField("double"))
    self.assertEqual(len(nested_doubles.array.elements), 3)
    # An empty inner list is legal once the outer element type is known.
    self.assertEqual(
        [len(e.array.elements) for e in nested_doubles.array.elements], [2, 2, 0]
    )


if __name__ == "__main__":
from pyspark.sql.tests.connect.test_connect_plan import * # noqa: F401
Expand Down

0 comments on commit d6d0fc7

Please sign in to comment.