[MINOR][ML] Rename weights to coefficients for examples/DeveloperApiExample

Rename `weights` to `coefficients` for examples/DeveloperApiExample.

cc mengxr jkbradley

Author: Yanbo Liang <[email protected]>

Closes #10280 from yanboliang/spark-coefficients.
yanboliang authored and jkbradley committed Dec 16, 2015
1 parent bc1ff9f commit b24c12d
Showing 2 changed files with 19 additions and 19 deletions.
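
The diff below applies the rename first to the Java version of the example (`MyJavaLogisticRegression`) and then to the Scala version (`MyLogisticRegression`). For reference, here is a hypothetical usage sketch, not part of this commit, written along the lines of the example's own `main` method: it assumes an existing `SparkContext` named `sc` and that the Scala example classes are visible from the calling code, and it reads the member back under its new name.

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.SQLContext

// Hypothetical sketch (not from this commit): fit the example estimator on a
// toy dataset and read back the renamed member.
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// Toy training data in the (label, features) shape the example expects.
val training = sc.parallelize(Seq(
  LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
  LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)))).toDF()

val model = new MyLogisticRegression().fit(training)

// The member is now `coefficients`, matching Spark ML's LogisticRegressionModel
// terminology; before this commit the example called it `weights`.
println(s"Model coefficients: ${model.coefficients}")
```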
@@ -89,7 +89,7 @@ public static void main(String[] args) throws Exception {
}
if (sumPredictions != 0.0) {
throw new Exception("MyJavaLogisticRegression predicted something other than 0," +
" even though all weights are 0!");
" even though all coefficients are 0!");
}

jsc.stop();
@@ -149,12 +149,12 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset) {
// Extract columns from data using helper method.
JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset).toJavaRDD();

- // Do learning to estimate the weight vector.
+ // Do learning to estimate the coefficients vector.
int numFeatures = oldDataset.take(1).get(0).features().size();
- Vector weights = Vectors.zeros(numFeatures); // Learning would happen here.
+ Vector coefficients = Vectors.zeros(numFeatures); // Learning would happen here.

// Create a model, and return it.
- return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this);
+ return new MyJavaLogisticRegressionModel(uid(), coefficients).setParent(this);
}

@Override
@@ -173,12 +173,12 @@ public MyJavaLogisticRegression copy(ParamMap extra) {
class MyJavaLogisticRegressionModel
extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> {

- private Vector weights_;
- public Vector weights() { return weights_; }
+ private Vector coefficients_;
+ public Vector coefficients() { return coefficients_; }

- public MyJavaLogisticRegressionModel(String uid, Vector weights) {
+ public MyJavaLogisticRegressionModel(String uid, Vector coefficients) {
this.uid_ = uid;
- this.weights_ = weights;
+ this.coefficients_ = coefficients;
}

private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg");
@@ -208,7 +208,7 @@ public String uid() {
* modifier.
*/
public Vector predictRaw(Vector features) {
- double margin = BLAS.dot(features, weights_);
+ double margin = BLAS.dot(features, coefficients_);
// There are 2 classes (binary classification), so we return a length-2 vector,
// where index i corresponds to class i (i = 0, 1).
return Vectors.dense(-margin, margin);
@@ -222,7 +222,7 @@ public Vector predictRaw(Vector features) {
/**
* Number of features the model was trained on.
*/
- public int numFeatures() { return weights_.size(); }
+ public int numFeatures() { return coefficients_.size(); }

/**
* Create a copy of the model.
@@ -235,7 +235,7 @@ public Vector predictRaw(Vector features) {
*/
@Override
public MyJavaLogisticRegressionModel copy(ParamMap extra) {
- return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra)
+ return copyValues(new MyJavaLogisticRegressionModel(uid(), coefficients_), extra)
.setParent(parent());
}
}
@@ -75,7 +75,7 @@ object DeveloperApiExample {
prediction
}.sum
assert(sumPredictions == 0.0,
"MyLogisticRegression predicted something other than 0, even though all weights are 0!")
"MyLogisticRegression predicted something other than 0, even though all coefficients are 0!")

sc.stop()
}
@@ -124,12 +124,12 @@ private class MyLogisticRegression(override val uid: String)
// Extract columns from data using helper method.
val oldDataset = extractLabeledPoints(dataset)

- // Do learning to estimate the weight vector.
+ // Do learning to estimate the coefficients vector.
val numFeatures = oldDataset.take(1)(0).features.size
- val weights = Vectors.zeros(numFeatures) // Learning would happen here.
+ val coefficients = Vectors.zeros(numFeatures) // Learning would happen here.

// Create a model, and return it.
- new MyLogisticRegressionModel(uid, weights).setParent(this)
+ new MyLogisticRegressionModel(uid, coefficients).setParent(this)
}

override def copy(extra: ParamMap): MyLogisticRegression = defaultCopy(extra)
@@ -142,7 +142,7 @@ private class MyLogisticRegression(override val uid: String)
*/
private class MyLogisticRegressionModel(
override val uid: String,
- val weights: Vector)
+ val coefficients: Vector)
extends ClassificationModel[Vector, MyLogisticRegressionModel]
with MyLogisticRegressionParams {

@@ -163,7 +163,7 @@ private class MyLogisticRegressionModel(
* confidence for that label.
*/
override protected def predictRaw(features: Vector): Vector = {
- val margin = BLAS.dot(features, weights)
+ val margin = BLAS.dot(features, coefficients)
// There are 2 classes (binary classification), so we return a length-2 vector,
// where index i corresponds to class i (i = 0, 1).
Vectors.dense(-margin, margin)
@@ -173,7 +173,7 @@ private class MyLogisticRegressionModel(
override val numClasses: Int = 2

/** Number of features the model was trained on. */
- override val numFeatures: Int = weights.size
+ override val numFeatures: Int = coefficients.size

/**
* Create a copy of the model.
@@ -182,7 +182,7 @@ private class MyLogisticRegressionModel(
* This is used for the default implementation of [[transform()]].
*/
override def copy(extra: ParamMap): MyLogisticRegressionModel = {
- copyValues(new MyLogisticRegressionModel(uid, weights), extra).setParent(parent)
+ copyValues(new MyLogisticRegressionModel(uid, coefficients), extra).setParent(parent)
}
}
// scalastyle:on println
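
For intuition on the assertion near the top of each file's diff: `predictRaw` returns the length-2 raw-prediction vector `[-margin, margin]`, and the model picks the class with the largest raw score, so the all-zero coefficient vector produced by the placeholder "learning" step gives a zero margin and class 0 for every row. A minimal plain-Scala sketch of that arithmetic (the feature values are made up; only the logic mirrors the example, which computes the dot product with BLAS.dot):

```scala
// Plain-Scala sketch of the margin / raw-prediction arithmetic from the example.
val coefficients = Array(0.0, 0.0, 0.0) // the placeholder "learning" step leaves these at zero
val features = Array(2.0, 1.0, -1.0)

// margin = coefficients . features
val margin = coefficients.zip(features).map { case (c, f) => c * f }.sum

// Length-2 raw prediction: index 0 -> class 0, index 1 -> class 1.
val rawPrediction = Array(-margin, margin)

// The predicted class is the index of the largest raw score (ties go to class 0),
// so a zero margin always yields class 0.0, which is what the sumPredictions
// assertion in the example checks.
val prediction = if (rawPrediction(1) > rawPrediction(0)) 1.0 else 0.0

println(s"margin=$margin, prediction=$prediction")
```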
