[Minor] [SQL] Cleans up DataFrame variable names and toDF() calls
Although we've migrated to the DataFrame API, lots of code still uses `rdd` or `srdd` as local variable names. This PR tries to address these naming inconsistencies and some other minor DataFrame related style issues.

Review on Reviewable: https://reviewable.io/reviews/apache/spark/4670

Author: Cheng Lian <[email protected]>

Closes apache#4670 from liancheng/df-cleanup and squashes the following commits:

3e14448 [Cheng Lian] Cleans up DataFrame variable names and toDF() calls
liancheng authored and rxin committed Feb 18, 2015
1 parent 3912d33 commit 61ab085
Showing 37 changed files with 250 additions and 259 deletions.
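
For context, here is a minimal before/after sketch (in Scala) of the two conventions this commit standardizes: calling the zero-argument conversion with explicit parentheses as `toDF()`, and naming DataFrame-typed locals `df` rather than `rdd`/`srdd`. The sketch assumes a Spark 1.3-era setup with a `SparkContext` and `SQLContext`; the `Record` case class and the `NamingSketch` object are illustrative and not taken from this diff.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Illustrative case class used for schema inference; not part of this commit.
case class Record(key: Int, value: String)

object NamingSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("NamingSketch"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._ // enables toDF() on RDDs of case classes

    // Before: DataFrame stored in a variable named like an RDD, conversion called without parens.
    val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))).toDF
    rdd.registerTempTable("records")

    // After: explicit empty parentheses, and a name that says the value is a DataFrame.
    val df = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))).toDF()
    df.registerTempTable("records")

    sc.stop()
  }
}

Both forms compile; the parentheses simply make it explicit that `toDF()` is a conversion call rather than a field access, which is the style the diff below applies throughout.
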
@@ -90,7 +90,7 @@ object CrossValidatorExample {
crossval.setNumFolds(2) // Use 3+ in practice

// Run cross-validation, and choose the best set of parameters.
-val cvModel = crossval.fit(training.toDF)
+val cvModel = crossval.fit(training.toDF())

// Prepare test documents, which are unlabeled.
val test = sc.parallelize(Seq(

@@ -58,7 +58,7 @@ object DeveloperApiExample {
lr.setMaxIter(10)

// Learn a LogisticRegression model. This uses the parameters stored in lr.
-val model = lr.fit(training.toDF)
+val model = lr.fit(training.toDF())

// Prepare test data.
val test = sc.parallelize(Seq(
@@ -67,7 +67,7 @@ object DeveloperApiExample {
LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))

// Make predictions on test data.
-val sumPredictions: Double = model.transform(test.toDF)
+val sumPredictions: Double = model.transform(test.toDF())
.select("features", "label", "prediction")
.collect()
.map { case Row(features: Vector, label: Double, prediction: Double) =>

@@ -137,9 +137,9 @@ object MovieLensALS {
.setRegParam(params.regParam)
.setNumBlocks(params.numBlocks)

-val model = als.fit(training.toDF)
+val model = als.fit(training.toDF())

-val predictions = model.transform(test.toDF).cache()
+val predictions = model.transform(test.toDF()).cache()

// Evaluate the model.
// TODO: Create an evaluator to compute RMSE.
@@ -158,7 +158,7 @@ object MovieLensALS {

// Inspect false positives.
predictions.registerTempTable("prediction")
-sc.textFile(params.movies).map(Movie.parseMovie).toDF.registerTempTable("movie")
+sc.textFile(params.movies).map(Movie.parseMovie).toDF().registerTempTable("movie")
sqlContext.sql(
"""
|SELECT userId, prediction.movieId, title, rating, prediction

@@ -58,7 +58,7 @@ object SimpleParamsExample {
.setRegParam(0.01)

// Learn a LogisticRegression model. This uses the parameters stored in lr.
-val model1 = lr.fit(training.toDF)
+val model1 = lr.fit(training.toDF())
// Since model1 is a Model (i.e., a Transformer produced by an Estimator),
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
@@ -77,7 +77,7 @@ object SimpleParamsExample {

// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
-val model2 = lr.fit(training.toDF, paramMapCombined)
+val model2 = lr.fit(training.toDF(), paramMapCombined)
println("Model 2 was fit using parameters: " + model2.fittingParamMap)

// Prepare test data.
@@ -90,7 +90,7 @@ object SimpleParamsExample {
// LogisticRegression.transform will only use the 'features' column.
// Note that model2.transform() outputs a 'myProbability' column instead of the usual
// 'probability' column since we renamed the lr.probabilityCol parameter previously.
-model2.transform(test.toDF)
+model2.transform(test.toDF())
.select("features", "label", "myProbability", "prediction")
.collect()
.foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) =>

@@ -69,7 +69,7 @@ object SimpleTextClassificationPipeline {
.setStages(Array(tokenizer, hashingTF, lr))

// Fit the pipeline to training documents.
-val model = pipeline.fit(training.toDF)
+val model = pipeline.fit(training.toDF())

// Prepare test documents, which are unlabeled.
val test = sc.parallelize(Seq(
@@ -79,7 +79,7 @@ object SimpleTextClassificationPipeline {
Document(7L, "apache hadoop")))

// Make predictions on test documents.
-model.transform(test.toDF)
+model.transform(test.toDF())
.select("id", "text", "probability", "prediction")
.collect()
.foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>

@@ -81,7 +81,7 @@ object DatasetExample {
println(s"Loaded ${origData.count()} instances from file: ${params.input}")

// Convert input data to DataFrame explicitly.
-val df: DataFrame = origData.toDF
+val df: DataFrame = origData.toDF()
println(s"Inferred schema:\n${df.schema.prettyJson}")
println(s"Converted to DataFrame with ${df.count()} records")

@@ -34,7 +34,7 @@ object RDDRelation {
// Importing the SQL context gives access to all the SQL functions and implicit conversions.
import sqlContext.implicits._

-val df = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))).toDF
+val df = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))).toDF()
// Any RDD containing case classes can be registered as a table. The schema of the table is
// automatically inferred using scala reflection.
df.registerTempTable("records")

@@ -68,7 +68,7 @@ object HiveFromSpark {

// You can also register RDDs as temporary tables within a HiveContext.
val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i")))
-rdd.toDF.registerTempTable("records")
+rdd.toDF().registerTempTable("records")

// Queries can then join RDD data with data stored in Hive.
println("Result of SELECT *:")

@@ -102,7 +102,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))

// Create Parquet data.
-val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF
+val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF()
dataRDD.saveAsParquetFile(dataPath(path))
}

@@ -62,7 +62,7 @@ private[classification] object GLMClassificationModel {

// Create Parquet data.
val data = Data(weights, intercept, threshold)
-sc.parallelize(Seq(data), 1).toDF.saveAsParquetFile(Loader.dataPath(path))
+sc.parallelize(Seq(data), 1).toDF().saveAsParquetFile(Loader.dataPath(path))
}

/**

@@ -58,7 +58,7 @@ private[regression] object GLMRegressionModel {

// Create Parquet data.
val data = Data(weights, intercept)
-val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF
+val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF()
// TODO: repartition with 1 partition after SPARK-5532 gets fixed
dataRDD.saveAsParquetFile(Loader.dataPath(path))
}

@@ -197,7 +197,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
val nodes = model.topNode.subtreeIterator.toSeq
val dataRDD: DataFrame = sc.parallelize(nodes)
.map(NodeData.apply(0, _))
-.toDF
+.toDF()
dataRDD.saveAsParquetFile(Loader.dataPath(path))
}

@@ -289,7 +289,7 @@ private[tree] object TreeEnsembleModel {
// Create Parquet data.
val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) =>
tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node))
-}.toDF
+}.toDF()
dataRDD.saveAsParquetFile(Loader.dataPath(path))
}

@@ -358,8 +358,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
.setNumUserBlocks(numUserBlocks)
.setNumItemBlocks(numItemBlocks)
val alpha = als.getAlpha
-val model = als.fit(training.toDF)
-val predictions = model.transform(test.toDF)
+val model = als.fit(training.toDF())
+val predictions = model.transform(test.toDF())
.select("rating", "prediction")
.map { case Row(rating: Float, prediction: Float) =>
(rating.toDouble, prediction.toDouble)

@@ -124,7 +124,7 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* from a RDD of tuples into a [[DataFrame]] with meaningful names. For example:
* {{{
* val rdd: RDD[(Int, String)] = ...
-* rdd.toDF // this implicit conversion creates a DataFrame with column name _1 and _2
+* rdd.toDF() // this implicit conversion creates a DataFrame with column name _1 and _2
* rdd.toDF("id", "name") // this creates a DataFrame with column name "id" and "name"
* }}}
* @group basic

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala (4 changes: 2 additions & 2 deletions)
@@ -887,8 +887,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
* Registers the given [[DataFrame]] as a temporary table in the catalog. Temporary tables exist
* only during the lifetime of this instance of SQLContext.
*/
-private[sql] def registerDataFrameAsTable(rdd: DataFrame, tableName: String): Unit = {
-catalog.registerTable(Seq(tableName), rdd.logicalPlan)
+private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = {
+catalog.registerTable(Seq(tableName), df.logicalPlan)
}

/**

@@ -99,7 +99,7 @@ private[sql] trait ParquetTest {
* Writes `data` to a Parquet file and reads it back as a [[DataFrame]],
* which is then passed to `f`. The Parquet file will be deleted after `f` returns.
*/
-protected def withParquetRDD[T <: Product: ClassTag: TypeTag]
+protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag]
(data: Seq[T])
(f: DataFrame => Unit): Unit = {
withParquetFile(data)(path => f(sqlContext.parquetFile(path)))
@@ -120,8 +120,8 @@ private[sql] trait ParquetTest {
protected def withParquetTable[T <: Product: ClassTag: TypeTag]
(data: Seq[T], tableName: String)
(f: => Unit): Unit = {
-withParquetRDD(data) { rdd =>
-sqlContext.registerDataFrameAsTable(rdd, tableName)
+withParquetDataFrame(data) { df =>
+sqlContext.registerDataFrameAsTable(df, tableName)
withTempTable(tableName)(f)
}
}

@@ -18,17 +18,15 @@
package org.apache.spark.sql

import scala.concurrent.duration._
-import scala.language.implicitConversions
-import scala.language.postfixOps
+import scala.language.{implicitConversions, postfixOps}

import org.scalatest.concurrent.Eventually._

import org.apache.spark.sql.TestData._
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.implicits._
-import org.apache.spark.storage.{StorageLevel, RDDBlockId}
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}

case class BigData(s: String)

@@ -59,15 +57,15 @@ class CachedTableSuite extends QueryTest {

test("unpersist an uncached table will not raise exception") {
assert(None == cacheManager.lookupCachedData(testData))
-testData.unpersist(true)
+testData.unpersist(blocking = true)
assert(None == cacheManager.lookupCachedData(testData))
-testData.unpersist(false)
+testData.unpersist(blocking = false)
assert(None == cacheManager.lookupCachedData(testData))
testData.persist()
assert(None != cacheManager.lookupCachedData(testData))
-testData.unpersist(true)
+testData.unpersist(blocking = true)
assert(None == cacheManager.lookupCachedData(testData))
-testData.unpersist(false)
+testData.unpersist(blocking = false)
assert(None == cacheManager.lookupCachedData(testData))
}

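The unpersist changes above are purely stylistic: the positional boolean becomes a named argument so the call site says what the flag means. A minimal sketch of the two call styles, assuming an already-built DataFrame named df (illustrative, not from this diff):

// Positional form (old style in the suite): the bare boolean is easy to misread.
df.unpersist(true)

// Named-argument form (new style): blocking = true waits until the cached
// blocks are actually removed; blocking = false returns immediately.
df.unpersist(blocking = true)
df.unpersist(blocking = false)
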
sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala (26 changes: 12 additions & 14 deletions)
@@ -17,8 +17,6 @@

package org.apache.spark.sql

-import org.apache.spark.sql.TestData._
-
import scala.language.postfixOps

import org.apache.spark.sql.functions._
@@ -251,20 +249,20 @@ class DataFrameSuite extends QueryTest {
Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2)))

checkAnswer(
-arrayData.toDF.orderBy('data.getItem(0).asc),
-arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)
+arrayData.toDF().orderBy('data.getItem(0).asc),
+arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)

checkAnswer(
-arrayData.toDF.orderBy('data.getItem(0).desc),
-arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)
+arrayData.toDF().orderBy('data.getItem(0).desc),
+arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)

checkAnswer(
-arrayData.toDF.orderBy('data.getItem(1).asc),
-arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)
+arrayData.toDF().orderBy('data.getItem(1).asc),
+arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)

checkAnswer(
-arrayData.toDF.orderBy('data.getItem(1).desc),
-arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
+arrayData.toDF().orderBy('data.getItem(1).desc),
+arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
}

test("limit") {
@@ -273,11 +271,11 @@ class DataFrameSuite extends QueryTest {
testData.take(10).toSeq)

checkAnswer(
-arrayData.toDF.limit(1),
+arrayData.toDF().limit(1),
arrayData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))

checkAnswer(
-mapData.toDF.limit(1),
+mapData.toDF().limit(1),
mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
}

@@ -411,7 +409,7 @@ class DataFrameSuite extends QueryTest {
}

test("addColumn") {
-val df = testData.toDF.withColumn("newCol", col("key") + 1)
+val df = testData.toDF().withColumn("newCol", col("key") + 1)
checkAnswer(
df,
testData.collect().map { case Row(key: Int, value: String) =>
@@ -421,7 +419,7 @@ class DataFrameSuite extends QueryTest {
}

test("renameColumn") {
-val df = testData.toDF.withColumn("newCol", col("key") + 1)
+val df = testData.toDF().withColumn("newCol", col("key") + 1)
.withColumnRenamed("value", "valueRenamed")
checkAnswer(
df,

sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala (8 changes: 4 additions & 4 deletions)
@@ -40,8 +40,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
}

def assertJoin(sqlString: String, c: Class[_]): Any = {
-val rdd = sql(sqlString)
-val physical = rdd.queryExecution.sparkPlan
+val df = sql(sqlString)
+val physical = df.queryExecution.sparkPlan
val operators = physical.collect {
case j: ShuffledHashJoin => j
case j: HashOuterJoin => j
@@ -410,8 +410,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
}

test("left semi join") {
-val rdd = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
-checkAnswer(rdd,
+val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
+checkAnswer(df,
Row(1, 1) ::
Row(1, 2) ::
Row(2, 1) ::
[Diff truncated: the remaining changed files were not loaded in this view.]