Skip to content

Commit

Permalink
[SPARK-49552][PYTHON] Add DataFrame API support for new 'randstr' and…
Browse files Browse the repository at this point in the history
… 'uniform' SQL functions

### What changes were proposed in this pull request?

In apache#48004 we added new SQL functions `randstr` and `uniform`. This PR adds DataFrame API support for them.

For example, in Scala:

```
sql("create table t(col int not null) using csv")
sql("insert into t values (0)")
val df = sql("select col from t")
df.select(randstr(lit(5), lit(0)).alias("x")).select(length(col("x")))
> 5

df.select(uniform(lit(10), lit(20), lit(0)).alias("x")).selectExpr("x > 5")
> true
```

### Why are the changes needed?

This improves DataFrame parity with the SQL API.

### Does this PR introduce _any_ user-facing change?

Yes, see above.

### How was this patch tested?

This PR adds unit test coverage.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#48143 from dtenedor/dataframes-uniform-randstr.

Authored-by: Daniel Tenedorio <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
dtenedor authored and zhengruifeng committed Sep 25, 2024
1 parent a4fb6cb commit e2d2ab5
Show file tree
Hide file tree
Showing 7 changed files with 331 additions and 10 deletions.
2 changes: 2 additions & 0 deletions python/docs/source/reference/pyspark.sql/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ Mathematical Functions
try_multiply
try_subtract
unhex
uniform
width_bucket


Expand Down Expand Up @@ -189,6 +190,7 @@ String Functions
overlay
position
printf
randstr
regexp_count
regexp_extract
regexp_extract_all
Expand Down
28 changes: 28 additions & 0 deletions python/pyspark/sql/connect/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,22 @@ def unhex(col: "ColumnOrName") -> Column:
unhex.__doc__ = pysparkfuncs.unhex.__doc__


def uniform(
min: Union[Column, int, float],
max: Union[Column, int, float],
seed: Optional[Union[Column, int]] = None,
) -> Column:
if seed is None:
return _invoke_function_over_columns(
"uniform", lit(min), lit(max), lit(random.randint(0, sys.maxsize))
)
else:
return _invoke_function_over_columns("uniform", lit(min), lit(max), lit(seed))


uniform.__doc__ = pysparkfuncs.uniform.__doc__


def approxCountDistinct(col: "ColumnOrName", rsd: Optional[float] = None) -> Column:
warnings.warn("Deprecated in 3.4, use approx_count_distinct instead.", FutureWarning)
return approx_count_distinct(col, rsd)
Expand Down Expand Up @@ -2581,6 +2597,18 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column:
regexp_like.__doc__ = pysparkfuncs.regexp_like.__doc__


def randstr(length: Union[Column, int], seed: Optional[Union[Column, int]] = None) -> Column:
if seed is None:
return _invoke_function_over_columns(
"randstr", lit(length), lit(random.randint(0, sys.maxsize))
)
else:
return _invoke_function_over_columns("randstr", lit(length), lit(seed))


randstr.__doc__ = pysparkfuncs.randstr.__doc__


def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column:
return _invoke_function_over_columns("regexp_count", str, regexp)

Expand Down
92 changes: 92 additions & 0 deletions python/pyspark/sql/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11973,6 +11973,47 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column:
return _invoke_function_over_columns("regexp_like", str, regexp)


@_try_remote_functions
def randstr(length: Union[Column, int], seed: Optional[Union[Column, int]] = None) -> Column:
"""Returns a string of the specified length whose characters are chosen uniformly at random from
the following pool of characters: 0-9, a-z, A-Z. The random seed is optional. The string length
must be a constant two-byte or four-byte integer (SMALLINT or INT, respectively).

.. versionadded:: 4.0.0

Parameters
----------
length : :class:`~pyspark.sql.Column` or int
Number of characters in the string to generate.
seed : :class:`~pyspark.sql.Column` or int
Optional random number seed to use.

Returns
-------
:class:`~pyspark.sql.Column`
The generated random string with the specified length.

Examples
--------
>>> spark.createDataFrame([('3',)], ['a']) \\
... .select(randstr(lit(5), lit(0)).alias('result')) \\
... .selectExpr("length(result) > 0").show()
+--------------------+
|(length(result) > 0)|
+--------------------+
| true|
+--------------------+
"""
length = _enum_to_value(length)
length = lit(length)
if seed is None:
return _invoke_function_over_columns("randstr", length)
else:
seed = _enum_to_value(seed)
seed = lit(seed)
return _invoke_function_over_columns("randstr", length, seed)


@_try_remote_functions
def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column:
r"""Returns a count of the number of times that the Java regex pattern `regexp` is matched
Expand Down Expand Up @@ -12339,6 +12380,57 @@ def unhex(col: "ColumnOrName") -> Column:
return _invoke_function_over_columns("unhex", col)


@_try_remote_functions
def uniform(
min: Union[Column, int, float],
max: Union[Column, int, float],
seed: Optional[Union[Column, int]] = None,
) -> Column:
"""Returns a random value with independent and identically distributed (i.i.d.) values with the
specified range of numbers. The random seed is optional. The provided numbers specifying the
minimum and maximum values of the range must be constant. If both of these numbers are integers,
then the result will also be an integer. Otherwise if one or both of these are floating-point
numbers, then the result will also be a floating-point number.

.. versionadded:: 4.0.0

Parameters
----------
min : :class:`~pyspark.sql.Column`, int, or float
Minimum value in the range.
max : :class:`~pyspark.sql.Column`, int, or float
Maximum value in the range.
seed : :class:`~pyspark.sql.Column` or int
Optional random number seed to use.

Returns
-------
:class:`~pyspark.sql.Column`
The generated random number within the specified range.

Examples
--------
>>> spark.createDataFrame([('3',)], ['a']) \\
... .select(uniform(lit(0), lit(10), lit(0)).alias('result')) \\
... .selectExpr("result < 15").show()
+-------------+
|(result < 15)|
+-------------+
| true|
+-------------+
"""
min = _enum_to_value(min)
min = lit(min)
max = _enum_to_value(max)
max = lit(max)
if seed is None:
return _invoke_function_over_columns("uniform", min, max)
else:
seed = _enum_to_value(seed)
seed = lit(seed)
return _invoke_function_over_columns("uniform", min, max, seed)


@_try_remote_functions
def length(col: "ColumnOrName") -> Column:
"""Computes the character length of string data or number of bytes of binary data.
Expand Down
21 changes: 20 additions & 1 deletion python/pyspark/sql/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from pyspark.sql import Row, Window, functions as F, types
from pyspark.sql.avro.functions import from_avro, to_avro
from pyspark.sql.column import Column
from pyspark.sql.functions.builtin import nullifzero, zeroifnull
from pyspark.sql.functions.builtin import nullifzero, randstr, uniform, zeroifnull
from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils
from pyspark.testing.utils import have_numpy

Expand Down Expand Up @@ -1610,6 +1610,25 @@ def test_nullifzero_zeroifnull(self):
result = df.select(zeroifnull(df.a).alias("r")).collect()
self.assertEqual([Row(r=0), Row(r=1)], result)

def test_randstr_uniform(self):
df = self.spark.createDataFrame([(0,)], ["a"])
result = df.select(randstr(F.lit(5), F.lit(0)).alias("x")).selectExpr("length(x)").collect()
self.assertEqual([Row(5)], result)
# The random seed is optional.
result = df.select(randstr(F.lit(5)).alias("x")).selectExpr("length(x)").collect()
self.assertEqual([Row(5)], result)

df = self.spark.createDataFrame([(0,)], ["a"])
result = (
df.select(uniform(F.lit(10), F.lit(20), F.lit(0)).alias("x"))
.selectExpr("x > 5")
.collect()
)
self.assertEqual([Row(True)], result)
# The random seed is optional.
result = df.select(uniform(F.lit(10), F.lit(20)).alias("x")).selectExpr("x > 5").collect()
self.assertEqual([Row(True)], result)


class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin):
pass
Expand Down
45 changes: 45 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1896,6 +1896,26 @@ object functions {
*/
def randn(): Column = randn(SparkClassUtils.random.nextLong)

/**
* Returns a string of the specified length whose characters are chosen uniformly at random from
* the following pool of characters: 0-9, a-z, A-Z. The string length must be a constant
* two-byte or four-byte integer (SMALLINT or INT, respectively).
*
* @group string_funcs
* @since 4.0.0
*/
def randstr(length: Column): Column = Column.fn("randstr", length)

/**
* Returns a string of the specified length whose characters are chosen uniformly at random from
* the following pool of characters: 0-9, a-z, A-Z, with the chosen random seed. The string
* length must be a constant two-byte or four-byte integer (SMALLINT or INT, respectively).
*
* @group string_funcs
* @since 4.0.0
*/
def randstr(length: Column, seed: Column): Column = Column.fn("randstr", length, seed)

/**
* Partition ID.
*
Expand Down Expand Up @@ -3740,6 +3760,31 @@ object functions {
*/
def stack(cols: Column*): Column = Column.fn("stack", cols: _*)

/**
* Returns a random value with independent and identically distributed (i.i.d.) values with the
* specified range of numbers. The provided numbers specifying the minimum and maximum values of
* the range must be constant. If both of these numbers are integers, then the result will also
* be an integer. Otherwise if one or both of these are floating-point numbers, then the result
* will also be a floating-point number.
*
* @group math_funcs
* @since 4.0.0
*/
def uniform(min: Column, max: Column): Column = Column.fn("uniform", min, max)

/**
* Returns a random value with independent and identically distributed (i.i.d.) values with the
* specified range of numbers, with the chosen random seed. The provided numbers specifying the
* minimum and maximum values of the range must be constant. If both of these numbers are
* integers, then the result will also be an integer. Otherwise if one or both of these are
* floating-point numbers, then the result will also be a floating-point number.
*
* @group math_funcs
* @since 4.0.0
*/
def uniform(min: Column, max: Column, seed: Column): Column =
Column.fn("uniform", min, max, seed)

/**
* Returns a random value with independent and identically distributed (i.i.d.) uniformly
* distributed values in [0, 1).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,15 +206,18 @@ object Randn {
""",
since = "4.0.0",
group = "math_funcs")
case class Uniform(min: Expression, max: Expression, seedExpression: Expression)
case class Uniform(min: Expression, max: Expression, seedExpression: Expression, hideSeed: Boolean)
extends RuntimeReplaceable with TernaryLike[Expression] with RDG {
def this(min: Expression, max: Expression) = this(min, max, UnresolvedSeed)
def this(min: Expression, max: Expression) =
this(min, max, UnresolvedSeed, hideSeed = true)
def this(min: Expression, max: Expression, seedExpression: Expression) =
this(min, max, seedExpression, hideSeed = false)

final override lazy val deterministic: Boolean = false
override val nodePatterns: Seq[TreePattern] =
Seq(RUNTIME_REPLACEABLE, EXPRESSION_WITH_RANDOM_SEED)

override val dataType: DataType = {
override def dataType: DataType = {
val first = min.dataType
val second = max.dataType
(min.dataType, max.dataType) match {
Expand All @@ -240,6 +243,10 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression)
case _ => false
}

override def sql: String = {
s"uniform(${min.sql}, ${max.sql}${if (hideSeed) "" else s", ${seedExpression.sql}"})"
}

override def checkInputDataTypes(): TypeCheckResult = {
var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess
def requiredType = "integer or floating-point"
Expand Down Expand Up @@ -277,11 +284,11 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression)
override def third: Expression = seedExpression

override def withNewSeed(newSeed: Long): Expression =
Uniform(min, max, Literal(newSeed, LongType))
Uniform(min, max, Literal(newSeed, LongType), hideSeed)

override def withNewChildrenInternal(
newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
Uniform(newFirst, newSecond, newThird)
Uniform(newFirst, newSecond, newThird, hideSeed)

override def replacement: Expression = {
if (Seq(min, max, seedExpression).exists(_.dataType == NullType)) {
Expand All @@ -300,6 +307,13 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression)
}
}

object Uniform {
def apply(min: Expression, max: Expression): Uniform =
Uniform(min, max, UnresolvedSeed, hideSeed = true)
def apply(min: Expression, max: Expression, seedExpression: Expression): Uniform =
Uniform(min, max, seedExpression, hideSeed = false)
}

@ExpressionDescription(
usage = """
_FUNC_(length[, seed]) - Returns a string of the specified length whose characters are chosen
Expand All @@ -315,9 +329,13 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression)
""",
since = "4.0.0",
group = "string_funcs")
case class RandStr(length: Expression, override val seedExpression: Expression)
case class RandStr(
length: Expression, override val seedExpression: Expression, hideSeed: Boolean)
extends ExpressionWithRandomSeed with BinaryLike[Expression] with Nondeterministic {
def this(length: Expression) = this(length, UnresolvedSeed)
def this(length: Expression) =
this(length, UnresolvedSeed, hideSeed = true)
def this(length: Expression, seedExpression: Expression) =
this(length, seedExpression, hideSeed = false)

override def nullable: Boolean = false
override def dataType: DataType = StringType
Expand All @@ -339,9 +357,14 @@ case class RandStr(length: Expression, override val seedExpression: Expression)
rng = new XORShiftRandom(seed + partitionIndex)
}

override def withNewSeed(newSeed: Long): Expression = RandStr(length, Literal(newSeed, LongType))
override def withNewSeed(newSeed: Long): Expression =
RandStr(length, Literal(newSeed, LongType), hideSeed)
override def withNewChildrenInternal(newFirst: Expression, newSecond: Expression): Expression =
RandStr(newFirst, newSecond)
RandStr(newFirst, newSecond, hideSeed)

override def sql: String = {
s"randstr(${length.sql}${if (hideSeed) "" else s", ${seedExpression.sql}"})"
}

override def checkInputDataTypes(): TypeCheckResult = {
var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess
Expand Down Expand Up @@ -422,3 +445,11 @@ case class RandStr(length: Expression, override val seedExpression: Expression)
isNull = FalseLiteral)
}
}

object RandStr {
def apply(length: Expression): RandStr =
RandStr(length, UnresolvedSeed, hideSeed = true)
def apply(length: Expression, seedExpression: Expression): RandStr =
RandStr(length, seedExpression, hideSeed = false)
}

Loading

0 comments on commit e2d2ab5

Please sign in to comment.