Skip to content

Commit

Permalink
exposing the biases for multi-class logistic regression (dotnet#1224)
Browse files Browse the repository at this point in the history
* exposing the biases for multi-class logistic regression

* Adding the GetBiases to the cookbook samples.
  • Loading branch information
sfilipi authored Oct 11, 2018
1 parent c5746a8 commit 659686c
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 2 deletions.
3 changes: 3 additions & 0 deletions docs/code/MlNetCookBook.md
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,9 @@ var model = learningPipeline.Fit(trainData);
VBuffer<float>[] weights = null;
predictor.GetWeights(ref weights, out int numClasses);

// similarly we can also inspect the biases for the 3 classes
var biases = pred.GetBiases();

// Inspect the normalizer scales.
Console.WriteLine(string.Join(" ", normScales));
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ protected override VBuffer<float> InitializeWeightsFromPredictor(MulticlassLogis
if (srcPredictor.InputType.VectorSize != NumFeatures)
throw Contracts.Except("The input training data must have the same features used to train the input predictor.");

return InitializeWeights(srcPredictor.DenseWeightsEnumerable(), srcPredictor.BiasesEnumerable());
return InitializeWeights(srcPredictor.DenseWeightsEnumerable(), srcPredictor.GetBiases());
}

protected override MulticlassLogisticRegressionPredictor CreatePredictor()
Expand Down Expand Up @@ -952,7 +952,10 @@ internal IEnumerable<float> DenseWeightsEnumerable()
}
}

internal IEnumerable<float> BiasesEnumerable()
/// <summary>
/// Gets the biases for the logistic regression predictor.
/// </summary>
public IEnumerable<float> GetBiases()
{
return _biases;
}
Expand Down
3 changes: 3 additions & 0 deletions test/Microsoft.ML.StaticPipelineTesting/Training.cs
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,9 @@ public void SdcaMulticlass()
foreach (var w in weights)
Assert.True(w.Length == 4);

var biases = pred.GetBiases();
Assert.True(biases.Count() == 3);

var data = model.Read(dataSource);

// Just output some data on the schema for fun.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,9 @@ private void TrainAndInspectWeights(string dataPath)
VBuffer<float>[] weights = null;
predictor.GetWeights(ref weights, out int numClasses);

// similarly we can also inspect the biases for the 3 classes
var biases = predictor.GetBiases();

// Inspect the normalizer scales.
Console.WriteLine(string.Join(" ", normScales));
}
Expand Down

0 comments on commit 659686c

Please sign in to comment.