[SPARK-42664][CONNECT] Support bloomFilter function for `DataFrameStatFunctions`

### What changes were proposed in this pull request?
This PR uses `BloomFilterAggregate` to implement the `bloomFilter` function for `DataFrameStatFunctions`.
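
For illustration only (not part of the patch; it mirrors the calls in the new test suite), the API can be used from the Connect Scala client roughly as follows, assuming a `spark` session is in scope:

```scala
// Build a DataFrame with a single LongType column.
val df = spark.range(1000).toDF("id")

// Size the filter by expected item count and false positive probability (fpp) ...
val byFpp = df.stat.bloomFilter("id", 1000, 0.03)

// ... or by expected item count and an explicit number of bits.
val byBits = df.stat.bloomFilter("id", 1000, 64 * 5)

// Bloom filters have no false negatives, so every inserted value is found.
assert(byFpp.mightContainLong(42L))
assert(byBits.mightContainLong(999L))
```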

### Why are the changes needed?
Adds Spark Connect JVM client API coverage.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

- Added a new test
- Manually checked with Scala 2.13

Closes apache#42414 from LuciferYang/SPARK-42664-backup.

Authored-by: yangjie01 <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
LuciferYang authored and hvanhovell committed Aug 15, 2023
1 parent 2ab404f commit b9f1114
Showing 6 changed files with 248 additions and 8 deletions.
@@ -199,9 +199,9 @@ private static int optimalNumOfHashFunctions(long n, long m) {
* See http://en.wikipedia.org/wiki/Bloom_filter#Probability_of_false_positives for the formula.
*
* @param n expected insertions (must be positive)
- * @param p false positive rate (must be 0 < p < 1)
+ * @param p false positive rate (must be 0 &lt; p &lt; 1)
   */
- private static long optimalNumOfBits(long n, double p) {
+ public static long optimalNumOfBits(long n, double p) {
return (long) (-n * Math.log(p) / (Math.log(2) * Math.log(2)));
}
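
For reference (not part of the patch), the expression above is the standard optimal-size formula for a Bloom filter, truncated to a whole number of bits by the `(long)` cast: for `n` expected insertions and target false positive rate `p`,

```latex
m = \frac{-n \ln p}{(\ln 2)^2}
```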

@@ -18,14 +18,16 @@
package org.apache.spark.sql

import java.{lang => jl, util => ju}
import java.io.ByteArrayInputStream

import scala.collection.JavaConverters._

import org.apache.spark.SparkException
import org.apache.spark.connect.proto.{Relation, StatSampleBy}
import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, PrimitiveDoubleEncoder}
import org.apache.spark.sql.functions.lit
-import org.apache.spark.util.sketch.CountMinSketch
+import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}

/**
* Statistic functions for `DataFrame`s.
@@ -584,6 +586,90 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo
}
CountMinSketch.readFrom(ds.head())
}

/**
* Builds a Bloom filter over a specified column.
*
* @param colName
* name of the column over which the filter is built
* @param expectedNumItems
* expected number of items which will be put into the filter.
* @param fpp
* expected false positive probability of the filter.
* @since 3.5.0
*/
def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = {
buildBloomFilter(Column(colName), expectedNumItems, -1L, fpp)
}

/**
* Builds a Bloom filter over a specified column.
*
* @param col
* the column over which the filter is built
* @param expectedNumItems
* expected number of items which will be put into the filter.
* @param fpp
* expected false positive probability of the filter.
* @since 3.5.0
*/
def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = {
buildBloomFilter(col, expectedNumItems, -1L, fpp)
}

/**
* Builds a Bloom filter over a specified column.
*
* @param colName
* name of the column over which the filter is built
* @param expectedNumItems
* expected number of items which will be put into the filter.
* @param numBits
* expected number of bits of the filter.
* @since 3.5.0
*/
def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = {
buildBloomFilter(Column(colName), expectedNumItems, numBits, Double.NaN)
}

/**
* Builds a Bloom filter over a specified column.
*
* @param col
* the column over which the filter is built
* @param expectedNumItems
* expected number of items which will be put into the filter.
* @param numBits
* expected number of bits of the filter.
* @since 3.5.0
*/
def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = {
buildBloomFilter(col, expectedNumItems, numBits, Double.NaN)
}

private def buildBloomFilter(
col: Column,
expectedNumItems: Long,
numBits: Long,
fpp: Double): BloomFilter = {
def numBitsValue: Long = if (!fpp.isNaN) {
BloomFilter.optimalNumOfBits(expectedNumItems, fpp)
} else {
numBits
}

if (fpp <= 0d || fpp >= 1d) {
throw new SparkException("False positive probability must be within range (0.0, 1.0)")
}
val agg = Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBitsValue))

val ds = sparkSession.newDataset(BinaryEncoder) { builder =>
builder.getProjectBuilder
.setInput(root)
.addExpressions(agg.expr)
}
BloomFilter.readFrom(new ByteArrayInputStream(ds.head()))
}
}

private object DataFrameStatFunctions {
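As a side note (a hand-worked trace, not from the patch): both overload families above funnel into the same three-argument `bloom_filter_agg` call; the fpp overloads simply derive the bit count client-side via `BloomFilter.optimalNumOfBits`, which this PR makes public. A rough sketch with the values used in the new tests:

```scala
import org.apache.spark.util.sketch.BloomFilter

// fpp path: the bit count is computed client-side before the plan is built.
val numBitsFromFpp = BloomFilter.optimalNumOfBits(1000L, 0.03) // roughly 7298 bits

// numBits path: the caller-supplied bit count is forwarded unchanged.
val numBitsExplicit = 64L * 5 // 320 bits

// Either way, the plan sent to the server carries
// bloom_filter_agg(col, lit(expectedNumItems), lit(numBits)).
```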
@@ -176,4 +176,91 @@ class ClientDataFrameStatSuite extends RemoteSparkSession {
assert(sketch.relativeError() === 0.001)
assert(sketch.confidence() === 0.99 +- 5e-3)
}

test("Bloom filter -- Long Column") {
val session = spark
import session.implicits._
val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997).map(_.toLong)
val df = data.toDF("id")
val negativeValues = Seq(-11, 1021, 32767).map(_.toLong)
checkBloomFilter(data, negativeValues, df)
}

test("Bloom filter -- Int Column") {
val session = spark
import session.implicits._
val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997)
val df = data.toDF("id")
val negativeValues = Seq(-11, 1021, 32767)
checkBloomFilter(data, negativeValues, df)
}

test("Bloom filter -- Short Column") {
val session = spark
import session.implicits._
val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997).map(_.toShort)
val df = data.toDF("id")
val negativeValues = Seq(-11, 1021, 32767).map(_.toShort)
checkBloomFilter(data, negativeValues, df)
}

test("Bloom filter -- Byte Column") {
val session = spark
import session.implicits._
val data = Seq(-32, -5, 1, 17, 39, 43, 101, 127).map(_.toByte)
val df = data.toDF("id")
val negativeValues = Seq(-101, 55, 113).map(_.toByte)
checkBloomFilter(data, negativeValues, df)
}

test("Bloom filter -- String Column") {
val session = spark
import session.implicits._
val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997).map(_.toString)
val df = data.toDF("id")
val negativeValues = Seq(-11, 1021, 32767).map(_.toString)
checkBloomFilter(data, negativeValues, df)
}

private def checkBloomFilter(
data: Seq[Any],
notContainValues: Seq[Any],
df: DataFrame): Unit = {
val filter1 = df.stat.bloomFilter("id", 1000, 0.03)
assert(filter1.expectedFpp() - 0.03 < 1e-3)
assert(data.forall(filter1.mightContain))
assert(notContainValues.forall(n => !filter1.mightContain(n)))
val filter2 = df.stat.bloomFilter("id", 1000, 64 * 5)
assert(filter2.bitSize() == 64 * 5)
assert(data.forall(filter2.mightContain))
assert(notContainValues.forall(n => !filter2.mightContain(n)))
}

test("Bloom filter -- Wrong dataType Column") {
val session = spark
import session.implicits._
val data = Range(0, 1000).map(_.toDouble)
val message = intercept[AnalysisException] {
data.toDF("id").stat.bloomFilter("id", 1000, 0.03)
}.getMessage
assert(message.contains("DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE"))
}

test("Bloom filter test invalid inputs") {
val df = spark.range(1000).toDF("id")
val message1 = intercept[SparkException] {
df.stat.bloomFilter("id", -1000, 100)
}.getMessage
assert(message1.contains("Expected insertions must be positive"))

val message2 = intercept[SparkException] {
df.stat.bloomFilter("id", 1000, -100)
}.getMessage
assert(message2.contains("Number of bits must be positive"))

val message3 = intercept[SparkException] {
df.stat.bloomFilter("id", 1000, -1.0)
}.getMessage
assert(message3.contains("False positive probability must be within range (0.0, 1.0)"))
}
}
@@ -163,9 +163,6 @@ object CheckConnectJvmClientCompatibility {
// DataFrameNaFunctions
ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameNaFunctions.fillValue"),

-    // DataFrameStatFunctions
-    ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameStatFunctions.bloomFilter"),
-
// Dataset
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.Dataset$" // private[sql]
@@ -48,6 +48,7 @@ import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, Mu
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical
@@ -1738,6 +1739,36 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
val ignoreNulls = extractBoolean(children(3), "ignoreNulls")
Some(Lead(children.head, children(1), children(2), ignoreNulls))

case "bloom_filter_agg" if fun.getArgumentsCount == 3 =>
// [col, expectedNumItems: Long, numBits: Long]
val children = fun.getArgumentsList.asScala.map(transformExpression)

// Check expectedNumItems is LongType and value greater than 0L
val expectedNumItemsExpr = children(1)
val expectedNumItems = expectedNumItemsExpr match {
case Literal(l: Long, LongType) => l
case _ =>
throw InvalidPlanInput("Expected insertions must be long literal.")
}
if (expectedNumItems <= 0L) {
throw InvalidPlanInput("Expected insertions must be positive.")
}

val numBitsExpr = children(2)
// Check numBits is LongType and value greater than 0L
numBitsExpr match {
case Literal(numBits: Long, LongType) =>
if (numBits <= 0L) {
throw InvalidPlanInput("Number of bits must be positive.")
}
case _ =>
throw InvalidPlanInput("Number of bits must be long literal.")
}

Some(
new BloomFilterAggregate(children.head, expectedNumItemsExpr, numBitsExpr)
.toAggregateExpression())

case "window" if Seq(2, 3, 4).contains(fun.getArgumentsCount) =>
val children = fun.getArgumentsList.asScala.map(transformExpression)
val timeCol = children.head
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.trees.TernaryLike
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{RUNTIME_BLOOM_FILTER_MAX_NUM_BITS, RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.sketch.BloomFilter

/**
@@ -78,7 +79,7 @@ case class BloomFilterAggregate(
"exprName" -> "estimatedNumItems or numBits"
)
)
-      case (LongType, LongType, LongType) =>
+      case (LongType | IntegerType | ShortType | ByteType | StringType, LongType, LongType) =>
if (!estimatedNumItemsExpression.foldable) {
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
@@ -150,6 +151,15 @@
Math.min(numBitsExpression.eval().asInstanceOf[Number].longValue,
SQLConf.get.getConf(RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))

// Mark as lazy so that `updater` is not evaluated during tree transformation.
private lazy val updater: BloomFilterUpdater = child.dataType match {
case LongType => LongUpdater
case IntegerType => IntUpdater
case ShortType => ShortUpdater
case ByteType => ByteUpdater
case StringType => BinaryUpdater
}

override def first: Expression = child

override def second: Expression = estimatedNumItemsExpression
@@ -174,7 +184,7 @@ case class BloomFilterAggregate(
if (value == null) {
return buffer
}
-    buffer.putLong(value.asInstanceOf[Long])
+    updater.update(buffer, value)
buffer
}

Expand Down Expand Up @@ -224,3 +234,32 @@ object BloomFilterAggregate {
bloomFilter
}
}

private trait BloomFilterUpdater {
def update(bf: BloomFilter, v: Any): Boolean
}

private object LongUpdater extends BloomFilterUpdater with Serializable {
override def update(bf: BloomFilter, v: Any): Boolean =
bf.putLong(v.asInstanceOf[Long])
}

private object IntUpdater extends BloomFilterUpdater with Serializable {
override def update(bf: BloomFilter, v: Any): Boolean =
bf.putLong(v.asInstanceOf[Int])
}

private object ShortUpdater extends BloomFilterUpdater with Serializable {
override def update(bf: BloomFilter, v: Any): Boolean =
bf.putLong(v.asInstanceOf[Short])
}

private object ByteUpdater extends BloomFilterUpdater with Serializable {
override def update(bf: BloomFilter, v: Any): Boolean =
bf.putLong(v.asInstanceOf[Byte])
}

private object BinaryUpdater extends BloomFilterUpdater with Serializable {
override def update(bf: BloomFilter, v: Any): Boolean =
bf.putBinary(v.asInstanceOf[UTF8String].getBytes)
}
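
A small standalone sketch (not part of the patch) of why the integral updaters above can all delegate to `putLong`: byte, short, and int values are widened to long, so the filter hashes the same 64-bit representation that long-based lookups use. It assumes `spark-sketch` is on the classpath:

```scala
import org.apache.spark.util.sketch.BloomFilter

val bf = BloomFilter.create(100) // expected number of items; value made up for the example

val fromInt: Int = 127
bf.putLong(fromInt) // widened to 127L, as IntUpdater does

// A lookup on the long representation sees the same hash, so the value is found.
assert(bf.mightContainLong(127L))
```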
