[SPARK-15084][PYTHON][SQL] Use builder pattern to create SparkSession in PySpark.

## What changes were proposed in this pull request?

This is a Python port of the corresponding Scala builder pattern code. `sql.py` is modified as a target example case.
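For reference, a minimal usage sketch (not part of the patch) of the builder API this change introduces; the master URL, application name, and config key below are placeholder values:

```python
from pyspark.sql import SparkSession

# Each builder method returns the builder itself, so calls can be chained.
spark = SparkSession.builder \
    .master("local[2]") \
    .appName("BuilderExample") \
    .config("spark.some.config.option", "some-value") \
    .getOrCreate()

# The resulting session exposes the usual DataFrame entry points.
df = spark.createDataFrame([("Alice", 19), ("Bob", 23)], ["name", "age"])
df.show()

spark.stop()
```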

## How was this patch tested?

Manual.

Author: Dongjoon Hyun <[email protected]>

Closes apache#12860 from dongjoon-hyun/SPARK-15084.
dongjoon-hyun authored and Andrew Or committed May 4, 2016
1 parent c1839c9 commit 0903a18
Showing 2 changed files with 105 additions and 21 deletions.
35 changes: 15 additions & 20 deletions examples/src/main/python/sql.py
@@ -20,33 +20,28 @@
 import os
 import sys
 
-from pyspark import SparkContext
-from pyspark.sql import SQLContext
+from pyspark.sql import SparkSession
 from pyspark.sql.types import Row, StructField, StructType, StringType, IntegerType
 
 
 if __name__ == "__main__":
-    sc = SparkContext(appName="PythonSQL")
-    sqlContext = SQLContext(sc)
+    spark = SparkSession.builder.appName("PythonSQL").getOrCreate()
 
-    # RDD is created from a list of rows
-    some_rdd = sc.parallelize([Row(name="John", age=19),
-                               Row(name="Smith", age=23),
-                               Row(name="Sarah", age=18)])
-    # Infer schema from the first row, create a DataFrame and print the schema
-    some_df = sqlContext.createDataFrame(some_rdd)
+    # A list of Rows. Infer schema from the first row, create a DataFrame and print the schema
+    rows = [Row(name="John", age=19), Row(name="Smith", age=23), Row(name="Sarah", age=18)]
+    some_df = spark.createDataFrame(rows)
     some_df.printSchema()
 
-    # Another RDD is created from a list of tuples
-    another_rdd = sc.parallelize([("John", 19), ("Smith", 23), ("Sarah", 18)])
+    # A list of tuples
+    tuples = [("John", 19), ("Smith", 23), ("Sarah", 18)]
     # Schema with two fields - person_name and person_age
     schema = StructType([StructField("person_name", StringType(), False),
                          StructField("person_age", IntegerType(), False)])
     # Create a DataFrame by applying the schema to the RDD and print the schema
-    another_df = sqlContext.createDataFrame(another_rdd, schema)
+    another_df = spark.createDataFrame(tuples, schema)
     another_df.printSchema()
     # root
-    #  |-- age: integer (nullable = true)
+    #  |-- age: long (nullable = true)
     #  |-- name: string (nullable = true)
 
     # A JSON dataset is pointed to by path.
@@ -57,24 +52,24 @@
     else:
         path = sys.argv[1]
     # Create a DataFrame from the file(s) pointed to by path
-    people = sqlContext.jsonFile(path)
+    people = spark.read.json(path)
     # root
     #  |-- person_name: string (nullable = false)
     #  |-- person_age: integer (nullable = false)
 
     # The inferred schema can be visualized using the printSchema() method.
     people.printSchema()
     # root
-    #  |-- age: IntegerType
-    #  |-- name: StringType
+    #  |-- age: long (nullable = true)
+    #  |-- name: string (nullable = true)
 
     # Register this DataFrame as a table.
-    people.registerAsTable("people")
+    people.registerTempTable("people")
 
     # SQL statements can be run by using the sql methods provided by sqlContext
-    teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
+    teenagers = spark.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
 
     for each in teenagers.collect():
         print(each[0])
 
-    sc.stop()
+    spark.stop()
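Condensed, the migration applied to `sql.py` above boils down to the following sketch (not part of the diff; the JSON path is a placeholder):

```python
from pyspark.sql import SparkSession

# Before (removed above):
#   sc = SparkContext(appName="PythonSQL")
#   sqlContext = SQLContext(sc)
#   people = sqlContext.jsonFile(path)
#   people.registerAsTable("people")
#   teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")

# After: a single SparkSession replaces SparkContext + SQLContext.
spark = SparkSession.builder.appName("PythonSQL").getOrCreate()
people = spark.read.json("examples/src/main/resources/people.json")  # placeholder path
people.registerTempTable("people")
teenagers = spark.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
for row in teenagers.collect():
    print(row[0])
spark.stop()
```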
91 changes: 90 additions & 1 deletion python/pyspark/sql/session.py
Expand Up @@ -19,6 +19,7 @@
import sys
import warnings
from functools import reduce
from threading import RLock

if sys.version >= '3':
basestring = unicode = str
@@ -58,16 +59,98 @@ def toDF(self, schema=None, sampleRatio=None):
 
 
 class SparkSession(object):
-    """Main entry point for Spark SQL functionality.
+    """The entry point to programming Spark with the Dataset and DataFrame API.
 
     A SparkSession can be used create :class:`DataFrame`, register :class:`DataFrame` as
     tables, execute SQL over tables, cache tables, and read parquet files.
+    To create a SparkSession, use the following builder pattern:
+
+    >>> spark = SparkSession.builder \
+        .master("local") \
+        .appName("Word Count") \
+        .config("spark.some.config.option", "some-value") \
+        .getOrCreate()
 
     :param sparkContext: The :class:`SparkContext` backing this SparkSession.
     :param jsparkSession: An optional JVM Scala SparkSession. If set, we do not instantiate a new
         SparkSession in the JVM, instead we make all calls to this object.
     """
 
+    class Builder(object):
+        """Builder for :class:`SparkSession`.
+        """
+
+        _lock = RLock()
+        _options = {}
+
+        @since(2.0)
+        def config(self, key=None, value=None, conf=None):
+            """Sets a config option. Options set using this method are automatically propagated to
+            both :class:`SparkConf` and :class:`SparkSession`'s own configuration.
+            For an existing SparkConf, use `conf` parameter.
+            >>> from pyspark.conf import SparkConf
+            >>> SparkSession.builder.config(conf=SparkConf())
+            <pyspark.sql.session...
+            For a (key, value) pair, you can omit parameter names.
+            >>> SparkSession.builder.config("spark.some.config.option", "some-value")
+            <pyspark.sql.session...
+            :param key: a key name string for configuration property
+            :param value: a value for configuration property
+            :param conf: an instance of :class:`SparkConf`
+            """
+            with self._lock:
+                if conf is None:
+                    self._options[key] = str(value)
+                else:
+                    for (k, v) in conf.getAll():
+                        self._options[k] = v
+                return self
+
+        @since(2.0)
+        def master(self, master):
+            """Sets the Spark master URL to connect to, such as "local" to run locally, "local[4]"
+            to run locally with 4 cores, or "spark://master:7077" to run on a Spark standalone
+            cluster.
+            :param master: a url for spark master
+            """
+            return self.config("spark.master", master)
+
+        @since(2.0)
+        def appName(self, name):
+            """Sets a name for the application, which will be shown in the Spark web UI.
+            :param name: an application name
+            """
+            return self.config("spark.app.name", name)
+
+        @since(2.0)
+        def enableHiveSupport(self):
+            """Enables Hive support, including connectivity to a persistent Hive metastore, support
+            for Hive serdes, and Hive user-defined functions.
+            """
+            return self.config("spark.sql.catalogImplementation", "hive")
+
+        @since(2.0)
+        def getOrCreate(self):
+            """Gets an existing :class:`SparkSession` or, if there is no existing one, creates a new
+            one based on the options set in this builder.
+            """
+            with self._lock:
+                from pyspark.conf import SparkConf
+                from pyspark.context import SparkContext
+                from pyspark.sql.context import SQLContext
+                sparkConf = SparkConf()
+                for key, value in self._options.items():
+                    sparkConf.set(key, value)
+                sparkContext = SparkContext.getOrCreate(sparkConf)
+                return SQLContext.getOrCreate(sparkContext).sparkSession
+
+    builder = Builder()
+
     _instantiatedContext = None
 
     @ignore_unicode_prefix
@@ -445,6 +528,12 @@ def read(self):
         """
         return DataFrameReader(self._wrapped)
 
+    @since(2.0)
+    def stop(self):
+        """Stop the underlying :class:`SparkContext`.
+        """
+        self._sc.stop()
+
 
 def _test():
     import os
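A small sanity-check sketch of the new builder (not part of the patch). Because `getOrCreate()` funnels through `SparkContext.getOrCreate` and the shared `SQLContext`, a second build is expected to hand back the same session; that expectation is an assumption here rather than something the diff asserts, and the master/appName/config values are placeholders:

```python
from pyspark.sql import SparkSession

# Options set on the builder are collected in _options and copied into a
# SparkConf when getOrCreate() is called.
spark1 = SparkSession.builder \
    .master("local[2]") \
    .appName("BuilderSanityCheck") \
    .config("spark.some.config.option", "some-value") \
    .getOrCreate()

# A second getOrCreate() reuses the existing SparkContext/SQLContext,
# so it should resolve to the same session object (expected: True).
spark2 = SparkSession.builder.getOrCreate()
print(spark1 is spark2)

spark1.stop()
```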
