Skip to content

Commit

Permalink
[SPARK-5612][SQL] Move DataFrame implicit functions into SQLContext.implicits.
Browse files Browse the repository at this point in the history

Author: Reynold Xin <[email protected]>

Closes #4386 from rxin/df-implicits and squashes the following commits:

9d96606 [Reynold Xin] style fix
edd296b [Reynold Xin] ReplSuite
1c946ab [Reynold Xin] [SPARK-5612][SQL] Move DataFrame implicit functions into SQLContext.implicits.
  • Loading branch information
rxin committed Feb 5, 2015
1 parent 9d3a75e commit 7d789e1
Show file tree
Hide file tree
Showing 28 changed files with 60 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ object CrossValidatorExample {
val conf = new SparkConf().setAppName("CrossValidatorExample")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
import sqlContext._
import sqlContext.implicits._

// Prepare training documents, which are labeled.
val training = sparkContext.parallelize(Seq(
val training = sc.parallelize(Seq(
LabeledDocument(0L, "a b c d e spark", 1.0),
LabeledDocument(1L, "b d", 0.0),
LabeledDocument(2L, "spark f g h", 1.0),
Expand Down Expand Up @@ -92,7 +92,7 @@ object CrossValidatorExample {
val cvModel = crossval.fit(training)

// Prepare test documents, which are unlabeled.
val test = sparkContext.parallelize(Seq(
val test = sc.parallelize(Seq(
Document(4L, "spark i j k"),
Document(5L, "l m n"),
Document(6L, "mapreduce spark"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ object MovieLensALS {
val conf = new SparkConf().setAppName(s"MovieLensALS with $params")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
import sqlContext._
import sqlContext.implicits._

val ratings = sc.textFile(params.ratings).map(Rating.parseRating).cache()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ object SimpleParamsExample {
val conf = new SparkConf().setAppName("SimpleParamsExample")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
import sqlContext._
import sqlContext.implicits._

// Prepare training data.
// We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes
// into DataFrames, where it uses the case class metadata to infer the schema.
val training = sparkContext.parallelize(Seq(
val training = sc.parallelize(Seq(
LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
Expand Down Expand Up @@ -81,7 +81,7 @@ object SimpleParamsExample {
println("Model 2 was fit using parameters: " + model2.fittingParamMap)

// Prepare test documents.
val test = sparkContext.parallelize(Seq(
val test = sc.parallelize(Seq(
LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ object SimpleTextClassificationPipeline {
val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
import sqlContext._
import sqlContext.implicits._

// Prepare training documents, which are labeled.
val training = sparkContext.parallelize(Seq(
val training = sc.parallelize(Seq(
LabeledDocument(0L, "a b c d e spark", 1.0),
LabeledDocument(1L, "b d", 0.0),
LabeledDocument(2L, "spark f g h", 1.0),
Expand All @@ -71,7 +71,7 @@ object SimpleTextClassificationPipeline {
val model = pipeline.fit(training)

// Prepare test documents, which are unlabeled.
val test = sparkContext.parallelize(Seq(
val test = sc.parallelize(Seq(
Document(4L, "spark i j k"),
Document(5L, "l m n"),
Document(6L, "mapreduce spark"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ object DatasetExample {
val conf = new SparkConf().setAppName(s"DatasetExample with $params")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
import sqlContext._ // for implicit conversions
import sqlContext.implicits._ // for implicit conversions

// Load input data
val origData: RDD[LabeledPoint] = params.dataFormat match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ object RDDRelation {
val sqlContext = new SQLContext(sc)

// Importing sqlContext.implicits._ brings only the implicit conversions into scope; SQL
// functions such as sql() are called on sqlContext explicitly.
import sqlContext._
import sqlContext.implicits._

val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i")))
// Any RDD containing case classes can be registered as a table. The schema of the table is
Expand All @@ -41,15 +41,15 @@ object RDDRelation {

// Once tables have been registered, you can run SQL queries over them.
println("Result of SELECT *:")
sql("SELECT * FROM records").collect().foreach(println)
sqlContext.sql("SELECT * FROM records").collect().foreach(println)

// Aggregation queries are also supported.
val count = sql("SELECT COUNT(*) FROM records").collect().head.getLong(0)
val count = sqlContext.sql("SELECT COUNT(*) FROM records").collect().head.getLong(0)
println(s"COUNT(*): $count")

// The results of SQL queries are themselves RDDs and support all normal RDD functions. The
// items in the RDD are of type Row, which allows you to access each column by ordinal.
val rddFromSql = sql("SELECT key, value FROM records WHERE key < 10")
val rddFromSql = sqlContext.sql("SELECT key, value FROM records WHERE key < 10")

println("Result of RDD.map:")
rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println)
Expand All @@ -68,7 +68,7 @@ object RDDRelation {

// These files can also be registered as tables.
parquetFile.registerTempTable("parquetFile")
sql("SELECT * FROM parquetFile").collect().foreach(println)
sqlContext.sql("SELECT * FROM parquetFile").collect().foreach(println)

sc.stop()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ object HiveFromSpark {
// HiveContext. When not configured by the hive-site.xml, the context automatically
// creates metastore_db and warehouse in the current directory.
val hiveContext = new HiveContext(sc)
import hiveContext._
import hiveContext.implicits._
import hiveContext.sql

sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
sql(s"LOAD DATA LOCAL INPATH '${kv1File.getAbsolutePath}' INTO TABLE src")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class ALSModel private[ml] (
def setPredictionCol(value: String): this.type = set(predictionCol, value)

override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
import dataset.sqlContext.createDataFrame
import dataset.sqlContext.implicits._
val map = this.paramMap ++ paramMap
val users = userFactors.toDataFrame("id", "features")
val items = itemFactors.toDataFrame("id", "features")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
numItemBlocks: Int = 3,
targetRMSE: Double = 0.05): Unit = {
val sqlContext = this.sqlContext
import sqlContext.createDataFrame
import sqlContext.implicits._
val als = new ALS()
.setRank(rank)
.setRegParam(regParam)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,12 @@ class ReplSuite extends FunSuite {
assertDoesNotContain("Exception", output)
}

test("SPARK-2576 importing SQLContext.createDataFrame.") {
test("SPARK-2576 importing SQLContext.implicits._") {
// We need to use local-cluster to test this case.
val output = runInterpreter("local-cluster[1,1,512]",
"""
|val sqlContext = new org.apache.spark.sql.SQLContext(sc)
|import sqlContext.createDataFrame
|import sqlContext.implicits._
|case class TestCaseClass(value: Int)
|sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDataFrame.collect()
""".stripMargin)
Expand Down
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
// scalastyle:off
// Disable style checker so "implicits" object can start with lowercase i
/**
* (Scala-specific)
* Implicit methods available in Scala for converting common Scala objects into [[DataFrame]]s.
*/
object implicits {
Expand All @@ -192,8 +193,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
*
* @group userf
*/
// TODO: Remove implicit here.
implicit def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
SparkPlan.currentContext.set(self)
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
val attributeSeq = schema.toAttributes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ trait ParquetTest {
protected def withParquetFile[T <: Product: ClassTag: TypeTag]
(data: Seq[T])
(f: String => Unit): Unit = {
import sqlContext.implicits._
withTempPath { file =>
sparkContext.parallelize(data).saveAsParquetFile(file.getCanonicalPath)
f(file.getCanonicalPath)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ case class BigData(s: String)
class CachedTableSuite extends QueryTest {
TestData // Load test tables.

import org.apache.spark.sql.test.TestSQLContext.implicits._

def rddIdOf(tableName: String): Int = {
val executedPlan = table(tableName).queryExecution.executedPlan
executedPlan.collect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ import org.apache.spark.sql.Dsl._
import org.apache.spark.sql.types._

/* Implicits */
import org.apache.spark.sql.test.TestSQLContext.{createDataFrame, logicalPlanToSparkQuery}
import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery
import org.apache.spark.sql.test.TestSQLContext.implicits._

import scala.language.postfixOps

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
// Make sure the tables are loaded.
TestData

import org.apache.spark.sql.test.TestSQLContext.implicits._

var origZone: TimeZone = _
override protected def beforeAll() {
origZone = TimeZone.getDefault
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ case class ComplexReflectData(
dataField: Data)

class ScalaReflectionRelationSuite extends FunSuite {

import org.apache.spark.sql.test.TestSQLContext.implicits._

test("query case class RDD") {
val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3))
Expand Down
3 changes: 1 addition & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ import java.sql.Timestamp
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.Dsl._
import org.apache.spark.sql.test._
import org.apache.spark.sql.test.TestSQLContext.implicits._

/* Implicits */
import org.apache.spark.sql.test.TestSQLContext._

case class TestData(key: Int, value: String)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ import scala.beans.{BeanInfo, BeanProperty}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dsl._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.{udf => _, _}
import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql}
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class InMemoryColumnarQuerySuite extends QueryTest {
// Make sure the tables are loaded.
TestData

import org.apache.spark.sql.test.TestSQLContext.implicits._

test("simple columnar query") {
val plan = executePlan(testData.logicalPlan).executedPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
val originalColumnBatchSize = conf.columnBatchSize
val originalInMemoryPartitionPruning = conf.inMemoryPartitionPruning

import org.apache.spark.sql.test.TestSQLContext.implicits._

override protected def beforeAll(): Unit = {
// Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
}

test("fixed-length decimals") {
import org.apache.spark.sql.test.TestSQLContext.implicits._

def makeDecimalRDD(decimal: DecimalType): DataFrame =
sparkContext
.parallelize(0 to 1000)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ import org.apache.spark.sql.hive.test.TestHive._
case class TestData(key: Int, value: String)

class InsertIntoHiveTableSuite extends QueryTest {
import org.apache.spark.sql.hive.test.TestHive.implicits._

val testData = TestHive.sparkContext.parallelize(
(1 to 100).map(i => TestData(i, i.toString)))
testData.registerTempTable("testData")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ import org.apache.spark.sql.hive.test.TestHive._
* Tests for persisting tables created through the data sources API into the metastore.
*/
class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {

import org.apache.spark.sql.hive.test.TestHive.implicits._

override def afterEach(): Unit = {
reset()
if (ctasPath.exists()) Utils.deleteRecursively(ctasPath)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
private val originalTimeZone = TimeZone.getDefault
private val originalLocale = Locale.getDefault

import org.apache.spark.sql.hive.test.TestHive.implicits._

override def beforeAll() {
TestHive.cacheTables = true
// Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.hive.test.TestHive.{sparkContext, sql}
import org.apache.spark.sql.hive.test.TestHive.implicits._

case class Nested(a: Int, B: Int)
case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ case class ListStringCaseClass(l: Seq[String])
* A test suite for Hive custom UDFs.
*/
class HiveUdfSuite extends QueryTest {
import TestHive._

import TestHive.{udf, sql}
import TestHive.implicits._

test("spark sql udf test that returns a struct") {
udf.register("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ case class Nested3(f3: Int)
* valid, but Hive currently cannot execute it.
*/
class SQLQuerySuite extends QueryTest {

import org.apache.spark.sql.hive.test.TestHive.implicits._

test("SPARK-4512 Fix attribute reference resolution error when using SORT BY") {
checkAnswer(
sql("SELECT * FROM (SELECT key + key AS a FROM src SORT BY value) t ORDER BY t.a"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
var partitionedTableDir: File = null
var partitionedTableDirWithKey: File = null

import org.apache.spark.sql.hive.test.TestHive.implicits._

override def beforeAll(): Unit = {
partitionedTableDir = File.createTempFile("parquettests", "sparksql")
partitionedTableDir.delete()
Expand Down

0 comments on commit 7d789e1

Please sign in to comment.