forked from databricks/Spark-The-Definitive-Guide
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathAdvanced_Analytics_and_Machine_Learning-Chapter_26_Classification.py
82 lines (49 loc) · 1.64 KB
/
Advanced_Analytics_and_Machine_Learning-Chapter_26_Classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# Read the parquet dataset at /data/binary-classification, keeping the
# "features" column and casting "label" to double.
bInput = (
    spark.read.format("parquet")
    .load("/data/binary-classification")
    .selectExpr("features", "cast(label as double) as label")
)
# COMMAND ----------
# Build a LogisticRegression estimator with default parameters, list its
# parameters, and fit it on bInput.
from pyspark.ml.classification import LogisticRegression
lr = LogisticRegression()
print(lr.explainParams())  # see all parameters
lrModel = lr.fit(bInput)
# COMMAND ----------
# Show the fitted logistic regression model's coefficients and intercept.
print(lrModel.coefficients)
print(lrModel.intercept)
# COMMAND ----------
# Pull the summary off the fitted model, print its area under the ROC curve,
# then display the summary's `roc` and `pr` tables.
summary = lrModel.summary
print(summary.areaUnderROC)
summary.roc.show()
summary.pr.show()
# COMMAND ----------
# Bare expression: evaluates the summary's objectiveHistory attribute.
# NOTE(review): presumably relies on the notebook echoing the last expression
# of a cell; run as a plain script this line produces no output -- confirm.
summary.objectiveHistory
# COMMAND ----------
# Build a DecisionTreeClassifier with default parameters, list its
# parameters, and fit it on bInput.
from pyspark.ml.classification import DecisionTreeClassifier
dt = DecisionTreeClassifier()
print(dt.explainParams())
dtModel = dt.fit(bInput)
# COMMAND ----------
# Build a RandomForestClassifier with default parameters, list its
# parameters, and fit it on bInput. The fitted model is bound to
# `trainedModel`, a name later cells rebind.
from pyspark.ml.classification import RandomForestClassifier
rfClassifier = RandomForestClassifier()
print(rfClassifier.explainParams())
trainedModel = rfClassifier.fit(bInput)
# COMMAND ----------
# Build a GBTClassifier with default parameters, list its parameters, and
# fit it on bInput. Rebinds `trainedModel` to the fitted model.
from pyspark.ml.classification import GBTClassifier
gbtClassifier = GBTClassifier()
print(gbtClassifier.explainParams())
trainedModel = gbtClassifier.fit(bInput)
# COMMAND ----------
# Build a NaiveBayes estimator with default parameters, list its parameters,
# and fit it on the rows of bInput whose label is not 0. Rebinds
# `trainedModel` to the fitted model.
from pyspark.ml.classification import NaiveBayes
nb = NaiveBayes()
print(nb.explainParams())
trainedModel = nb.fit(bInput.where("label != 0"))
# COMMAND ----------
# Build RDD-based binary classification metrics from (prediction, label)
# pairs produced by a fitted model.
from pyspark.mllib.evaluation import BinaryClassificationMetrics
# The original referenced an undefined name `model` (NameError). `lrModel` is
# used because it is fitted on the full bInput and bound exactly once,
# whereas `trainedModel` is rebound by several cells.
# NOTE(review): swap in a different fitted model here if another one was
# intended.
out = (
    lrModel.transform(bInput)
    .select("prediction", "label")
    .rdd.map(lambda x: (float(x[0]), float(x[1])))
)
metrics = BinaryClassificationMetrics(out)
# COMMAND ----------
# Report the areas under the precision-recall and ROC curves, then display
# the ROC points.
print(metrics.areaUnderPR)
print(metrics.areaUnderROC)
print("Receiver Operating Characteristic")
# NOTE(review): the Python BinaryClassificationMetrics wrapper may expose only
# areaUnderROC / areaUnderPR; confirm `roc` exists on the target Spark
# version before relying on this line.
metrics.roc.toDF().show()
# COMMAND ----------