From e2d2ab510632cc1948cb6b4500e9da49036a96bd Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Wed, 25 Sep 2024 10:57:44 +0800 Subject: [PATCH] [SPARK-49552][PYTHON] Add DataFrame API support for new 'randstr' and 'uniform' SQL functions ### What changes were proposed in this pull request? In https://github.com/apache/spark/pull/48004 we added new SQL functions `randstr` and `uniform`. This PR adds DataFrame API support for them. For example, in Scala: ``` sql("create table t(col int not null) using csv") sql("insert into t values (0)") val df = sql("select col from t") df.select(randstr(lit(5), lit(0)).alias("x")).select(length(col("x"))) > 5 df.select(uniform(lit(10), lit(20), lit(0)).alias("x")).selectExpr("x > 5") > true ``` ### Why are the changes needed? This improves DataFrame parity with the SQL API. ### Does this PR introduce _any_ user-facing change? Yes, see above. ### How was this patch tested? This PR adds unit test coverage. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48143 from dtenedor/dataframes-uniform-randstr. Authored-by: Daniel Tenedorio Signed-off-by: Ruifeng Zheng --- .../reference/pyspark.sql/functions.rst | 2 + .../pyspark/sql/connect/functions/builtin.py | 28 +++++ python/pyspark/sql/functions/builtin.py | 92 ++++++++++++++++ python/pyspark/sql/tests/test_functions.py | 21 +++- .../org/apache/spark/sql/functions.scala | 45 ++++++++ .../expressions/randomExpressions.scala | 49 +++++++-- .../spark/sql/DataFrameFunctionsSuite.scala | 104 ++++++++++++++++++ 7 files changed, 331 insertions(+), 10 deletions(-) diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst index 4910a5b59273b..6248e71331656 100644 --- a/python/docs/source/reference/pyspark.sql/functions.rst +++ b/python/docs/source/reference/pyspark.sql/functions.rst @@ -148,6 +148,7 @@ Mathematical Functions try_multiply try_subtract unhex + uniform width_bucket @@ -189,6 +190,7 @@ String Functions overlay position printf + randstr regexp_count regexp_extract regexp_extract_all diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index 6953230f5b42e..27b12fff3c0ac 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -1007,6 +1007,22 @@ def unhex(col: "ColumnOrName") -> Column: unhex.__doc__ = pysparkfuncs.unhex.__doc__ +def uniform( + min: Union[Column, int, float], + max: Union[Column, int, float], + seed: Optional[Union[Column, int]] = None, +) -> Column: + if seed is None: + return _invoke_function_over_columns( + "uniform", lit(min), lit(max), lit(random.randint(0, sys.maxsize)) + ) + else: + return _invoke_function_over_columns("uniform", lit(min), lit(max), lit(seed)) + + +uniform.__doc__ = pysparkfuncs.uniform.__doc__ + + def approxCountDistinct(col: "ColumnOrName", rsd: Optional[float] = None) -> Column: warnings.warn("Deprecated in 3.4, use approx_count_distinct instead.", FutureWarning) return approx_count_distinct(col, rsd) @@ -2581,6 +2597,18 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: regexp_like.__doc__ = pysparkfuncs.regexp_like.__doc__ +def randstr(length: Union[Column, int], seed: Optional[Union[Column, int]] = None) -> Column: + if seed is None: + return _invoke_function_over_columns( + "randstr", lit(length), lit(random.randint(0, sys.maxsize)) + ) + else: + return _invoke_function_over_columns("randstr", lit(length), lit(seed)) + + +randstr.__doc__ = pysparkfuncs.randstr.__doc__ + + def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: return _invoke_function_over_columns("regexp_count", str, regexp) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 09a286fe7c94e..4ca39562cb20b 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -11973,6 +11973,47 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: return _invoke_function_over_columns("regexp_like", str, regexp) +@_try_remote_functions +def randstr(length: Union[Column, int], seed: Optional[Union[Column, int]] = None) -> Column: + """Returns a string of the specified length whose characters are chosen uniformly at random from + the following pool of characters: 0-9, a-z, A-Z. The random seed is optional. The string length + must be a constant two-byte or four-byte integer (SMALLINT or INT, respectively). + + .. versionadded:: 4.0.0 + + Parameters + ---------- + length : :class:`~pyspark.sql.Column` or int + Number of characters in the string to generate. + seed : :class:`~pyspark.sql.Column` or int + Optional random number seed to use. + + Returns + ------- + :class:`~pyspark.sql.Column` + The generated random string with the specified length. + + Examples + -------- + >>> spark.createDataFrame([('3',)], ['a']) \\ + ... .select(randstr(lit(5), lit(0)).alias('result')) \\ + ... .selectExpr("length(result) > 0").show() + +--------------------+ + |(length(result) > 0)| + +--------------------+ + | true| + +--------------------+ + """ + length = _enum_to_value(length) + length = lit(length) + if seed is None: + return _invoke_function_over_columns("randstr", length) + else: + seed = _enum_to_value(seed) + seed = lit(seed) + return _invoke_function_over_columns("randstr", length, seed) + + @_try_remote_functions def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: r"""Returns a count of the number of times that the Java regex pattern `regexp` is matched @@ -12339,6 +12380,57 @@ def unhex(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("unhex", col) +@_try_remote_functions +def uniform( + min: Union[Column, int, float], + max: Union[Column, int, float], + seed: Optional[Union[Column, int]] = None, +) -> Column: + """Returns a random value with independent and identically distributed (i.i.d.) values with the + specified range of numbers. The random seed is optional. The provided numbers specifying the + minimum and maximum values of the range must be constant. If both of these numbers are integers, + then the result will also be an integer. Otherwise if one or both of these are floating-point + numbers, then the result will also be a floating-point number. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + min : :class:`~pyspark.sql.Column`, int, or float + Minimum value in the range. + max : :class:`~pyspark.sql.Column`, int, or float + Maximum value in the range. + seed : :class:`~pyspark.sql.Column` or int + Optional random number seed to use. + + Returns + ------- + :class:`~pyspark.sql.Column` + The generated random number within the specified range. + + Examples + -------- + >>> spark.createDataFrame([('3',)], ['a']) \\ + ... .select(uniform(lit(0), lit(10), lit(0)).alias('result')) \\ + ... .selectExpr("result < 15").show() + +-------------+ + |(result < 15)| + +-------------+ + | true| + +-------------+ + """ + min = _enum_to_value(min) + min = lit(min) + max = _enum_to_value(max) + max = lit(max) + if seed is None: + return _invoke_function_over_columns("uniform", min, max) + else: + seed = _enum_to_value(seed) + seed = lit(seed) + return _invoke_function_over_columns("uniform", min, max, seed) + + @_try_remote_functions def length(col: "ColumnOrName") -> Column: """Computes the character length of string data or number of bytes of binary data. diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index a0ab9bc9c7d40..a51156e895c62 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -29,7 +29,7 @@ from pyspark.sql import Row, Window, functions as F, types from pyspark.sql.avro.functions import from_avro, to_avro from pyspark.sql.column import Column -from pyspark.sql.functions.builtin import nullifzero, zeroifnull +from pyspark.sql.functions.builtin import nullifzero, randstr, uniform, zeroifnull from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils from pyspark.testing.utils import have_numpy @@ -1610,6 +1610,25 @@ def test_nullifzero_zeroifnull(self): result = df.select(zeroifnull(df.a).alias("r")).collect() self.assertEqual([Row(r=0), Row(r=1)], result) + def test_randstr_uniform(self): + df = self.spark.createDataFrame([(0,)], ["a"]) + result = df.select(randstr(F.lit(5), F.lit(0)).alias("x")).selectExpr("length(x)").collect() + self.assertEqual([Row(5)], result) + # The random seed is optional. + result = df.select(randstr(F.lit(5)).alias("x")).selectExpr("length(x)").collect() + self.assertEqual([Row(5)], result) + + df = self.spark.createDataFrame([(0,)], ["a"]) + result = ( + df.select(uniform(F.lit(10), F.lit(20), F.lit(0)).alias("x")) + .selectExpr("x > 5") + .collect() + ) + self.assertEqual([Row(True)], result) + # The random seed is optional. + result = df.select(uniform(F.lit(10), F.lit(20)).alias("x")).selectExpr("x > 5").collect() + self.assertEqual([Row(True)], result) + class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin): pass diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index ab69789c75f50..93bff22621057 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -1896,6 +1896,26 @@ object functions { */ def randn(): Column = randn(SparkClassUtils.random.nextLong) + /** + * Returns a string of the specified length whose characters are chosen uniformly at random from + * the following pool of characters: 0-9, a-z, A-Z. The string length must be a constant + * two-byte or four-byte integer (SMALLINT or INT, respectively). + * + * @group string_funcs + * @since 4.0.0 + */ + def randstr(length: Column): Column = Column.fn("randstr", length) + + /** + * Returns a string of the specified length whose characters are chosen uniformly at random from + * the following pool of characters: 0-9, a-z, A-Z, with the chosen random seed. The string + * length must be a constant two-byte or four-byte integer (SMALLINT or INT, respectively). + * + * @group string_funcs + * @since 4.0.0 + */ + def randstr(length: Column, seed: Column): Column = Column.fn("randstr", length, seed) + /** * Partition ID. * @@ -3740,6 +3760,31 @@ object functions { */ def stack(cols: Column*): Column = Column.fn("stack", cols: _*) + /** + * Returns a random value with independent and identically distributed (i.i.d.) values with the + * specified range of numbers. The provided numbers specifying the minimum and maximum values of + * the range must be constant. If both of these numbers are integers, then the result will also + * be an integer. Otherwise if one or both of these are floating-point numbers, then the result + * will also be a floating-point number. + * + * @group math_funcs + * @since 4.0.0 + */ + def uniform(min: Column, max: Column): Column = Column.fn("uniform", min, max) + + /** + * Returns a random value with independent and identically distributed (i.i.d.) values with the + * specified range of numbers, with the chosen random seed. The provided numbers specifying the + * minimum and maximum values of the range must be constant. If both of these numbers are + * integers, then the result will also be an integer. Otherwise if one or both of these are + * floating-point numbers, then the result will also be a floating-point number. + * + * @group math_funcs + * @since 4.0.0 + */ + def uniform(min: Column, max: Column, seed: Column): Column = + Column.fn("uniform", min, max, seed) + /** * Returns a random value with independent and identically distributed (i.i.d.) uniformly * distributed values in [0, 1). diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index f329f8346b0de..ada0a73a67958 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -206,15 +206,18 @@ object Randn { """, since = "4.0.0", group = "math_funcs") -case class Uniform(min: Expression, max: Expression, seedExpression: Expression) +case class Uniform(min: Expression, max: Expression, seedExpression: Expression, hideSeed: Boolean) extends RuntimeReplaceable with TernaryLike[Expression] with RDG { - def this(min: Expression, max: Expression) = this(min, max, UnresolvedSeed) + def this(min: Expression, max: Expression) = + this(min, max, UnresolvedSeed, hideSeed = true) + def this(min: Expression, max: Expression, seedExpression: Expression) = + this(min, max, seedExpression, hideSeed = false) final override lazy val deterministic: Boolean = false override val nodePatterns: Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE, EXPRESSION_WITH_RANDOM_SEED) - override val dataType: DataType = { + override def dataType: DataType = { val first = min.dataType val second = max.dataType (min.dataType, max.dataType) match { @@ -240,6 +243,10 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression) case _ => false } + override def sql: String = { + s"uniform(${min.sql}, ${max.sql}${if (hideSeed) "" else s", ${seedExpression.sql}"})" + } + override def checkInputDataTypes(): TypeCheckResult = { var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess def requiredType = "integer or floating-point" @@ -277,11 +284,11 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression) override def third: Expression = seedExpression override def withNewSeed(newSeed: Long): Expression = - Uniform(min, max, Literal(newSeed, LongType)) + Uniform(min, max, Literal(newSeed, LongType), hideSeed) override def withNewChildrenInternal( newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = - Uniform(newFirst, newSecond, newThird) + Uniform(newFirst, newSecond, newThird, hideSeed) override def replacement: Expression = { if (Seq(min, max, seedExpression).exists(_.dataType == NullType)) { @@ -300,6 +307,13 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression) } } +object Uniform { + def apply(min: Expression, max: Expression): Uniform = + Uniform(min, max, UnresolvedSeed, hideSeed = true) + def apply(min: Expression, max: Expression, seedExpression: Expression): Uniform = + Uniform(min, max, seedExpression, hideSeed = false) +} + @ExpressionDescription( usage = """ _FUNC_(length[, seed]) - Returns a string of the specified length whose characters are chosen @@ -315,9 +329,13 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression) """, since = "4.0.0", group = "string_funcs") -case class RandStr(length: Expression, override val seedExpression: Expression) +case class RandStr( + length: Expression, override val seedExpression: Expression, hideSeed: Boolean) extends ExpressionWithRandomSeed with BinaryLike[Expression] with Nondeterministic { - def this(length: Expression) = this(length, UnresolvedSeed) + def this(length: Expression) = + this(length, UnresolvedSeed, hideSeed = true) + def this(length: Expression, seedExpression: Expression) = + this(length, seedExpression, hideSeed = false) override def nullable: Boolean = false override def dataType: DataType = StringType @@ -339,9 +357,14 @@ case class RandStr(length: Expression, override val seedExpression: Expression) rng = new XORShiftRandom(seed + partitionIndex) } - override def withNewSeed(newSeed: Long): Expression = RandStr(length, Literal(newSeed, LongType)) + override def withNewSeed(newSeed: Long): Expression = + RandStr(length, Literal(newSeed, LongType), hideSeed) override def withNewChildrenInternal(newFirst: Expression, newSecond: Expression): Expression = - RandStr(newFirst, newSecond) + RandStr(newFirst, newSecond, hideSeed) + + override def sql: String = { + s"randstr(${length.sql}${if (hideSeed) "" else s", ${seedExpression.sql}"})" + } override def checkInputDataTypes(): TypeCheckResult = { var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess @@ -422,3 +445,11 @@ case class RandStr(length: Expression, override val seedExpression: Expression) isNull = FalseLiteral) } } + +object RandStr { + def apply(length: Expression): RandStr = + RandStr(length, UnresolvedSeed, hideSeed = true) + def apply(length: Expression, seedExpression: Expression): RandStr = + RandStr(length, seedExpression, hideSeed = false) +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 0842b92e5d53c..016803635ff60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -411,6 +411,110 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer(df.select(nvl2(col("b"), col("a"), col("c"))), Seq(Row(null))) } + test("randstr function") { + withTable("t") { + sql("create table t(col int not null) using csv") + sql("insert into t values (0)") + val df = sql("select col from t") + checkAnswer( + df.select(randstr(lit(5), lit(0)).alias("x")).select(length(col("x"))), + Seq(Row(5))) + // The random seed is optional. + checkAnswer( + df.select(randstr(lit(5)).alias("x")).select(length(col("x"))), + Seq(Row(5))) + } + // Here we exercise some error cases. + val df = Seq((0)).toDF("a") + var expr = randstr(lit(10), lit("a")) + checkError( + intercept[AnalysisException](df.select(expr)), + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"randstr(10, a)\"", + "paramIndex" -> "second", + "inputSql" -> "\"a\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "INT or SMALLINT"), + context = ExpectedContext( + contextType = QueryContextType.DataFrame, + fragment = "randstr", + objectType = "", + objectName = "", + callSitePattern = "", + startIndex = 0, + stopIndex = 0)) + expr = randstr(col("a"), lit(10)) + checkError( + intercept[AnalysisException](df.select(expr)), + condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + parameters = Map( + "inputName" -> "length", + "inputType" -> "INT or SMALLINT", + "inputExpr" -> "\"a\"", + "sqlExpr" -> "\"randstr(a, 10)\""), + context = ExpectedContext( + contextType = QueryContextType.DataFrame, + fragment = "randstr", + objectType = "", + objectName = "", + callSitePattern = "", + startIndex = 0, + stopIndex = 0)) + } + + test("uniform function") { + withTable("t") { + sql("create table t(col int not null) using csv") + sql("insert into t values (0)") + val df = sql("select col from t") + checkAnswer( + df.select(uniform(lit(10), lit(20), lit(0)).alias("x")).selectExpr("x > 5"), + Seq(Row(true))) + // The random seed is optional. + checkAnswer( + df.select(uniform(lit(10), lit(20)).alias("x")).selectExpr("x > 5"), + Seq(Row(true))) + } + // Here we exercise some error cases. + val df = Seq((0)).toDF("a") + var expr = uniform(lit(10), lit("a")) + checkError( + intercept[AnalysisException](df.select(expr)), + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"uniform(10, a)\"", + "paramIndex" -> "second", + "inputSql" -> "\"a\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "integer or floating-point"), + context = ExpectedContext( + contextType = QueryContextType.DataFrame, + fragment = "uniform", + objectType = "", + objectName = "", + callSitePattern = "", + startIndex = 0, + stopIndex = 0)) + expr = uniform(col("a"), lit(10)) + checkError( + intercept[AnalysisException](df.select(expr)), + condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + parameters = Map( + "inputName" -> "min", + "inputType" -> "integer or floating-point", + "inputExpr" -> "\"a\"", + "sqlExpr" -> "\"uniform(a, 10)\""), + context = ExpectedContext( + contextType = QueryContextType.DataFrame, + fragment = "uniform", + objectType = "", + objectName = "", + callSitePattern = "", + startIndex = 0, + stopIndex = 0)) + } + test("zeroifnull function") { withTable("t") { // Here we exercise a non-nullable, non-foldable column.