Skip to content

Commit

Permalink
Added Predict function
Browse files Browse the repository at this point in the history
Added predict function along with its test. The current interface is the
same as in the KNN example. In other words, only the class string is
returned from the PredictOne function.
  • Loading branch information
tncardoso committed May 21, 2014
1 parent 86b18fe commit 90458d9
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 37 deletions.
105 changes: 88 additions & 17 deletions naive/bernoulli_nb.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,27 +40,39 @@ type BernoulliNBClassifier struct {
base.BaseEstimator
// Logarithm of each class prior
logClassPrior map[string]float64
// Log of conditional probability for each term. This vector should be
// accessed in the following way: p(f|c) = logCondProb[c][f].
// Conditional probability for each term. This vector should be
// accessed in the following way: p(f|c) = condProb[c][f].
// Logarithm is used in order to avoid underflow.
logCondProb map[string][]float64
condProb map[string][]float64
// Number of instances in each class. This is necessary in order to
// calculate the laplace smooth value during the Predict step.
classInstances map[string]int
// Number of features in the training set
features int
}

// Create a new Bernoulli Naive Bayes Classifier. The argument 'classes'
// is the number of possible labels in the classification task.
func NewBernoulliNBClassifier() *BernoulliNBClassifier {
nb := BernoulliNBClassifier{}
nb.logCondProb = make(map[string][]float64)
nb.condProb = make(map[string][]float64)
nb.logClassPrior = make(map[string]float64)
nb.features = 0
return &nb
}

// Fill data matrix with Bernoulli Naive Bayes model. All values
// necessary for calculating prior probability and p(f_i)
func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {

// Number of features in this training set
nb.features = 0
if X.Rows > 0 {
nb.features = len(X.GetRowVectorWithoutClass(0))
}

// Number of instances in class
classInstances := make(map[string]int)
nb.classInstances = make(map[string]int)

// Number of documents with given term (by class)
docsContainingTerm := make(map[string][]int)
Expand All @@ -70,14 +82,16 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
// version is used.
for r := 0; r < X.Rows; r++ {
class := X.GetClass(r)
docVector := X.GetRowVectorWithoutClass(r)

// increment number of instances in class
t, ok := classInstances[class]
t, ok := nb.classInstances[class]
if !ok { t = 0 }
classInstances[class] = t + 1
nb.classInstances[class] = t + 1


for feat := 0; feat < X.Cols; feat++ {
v := X.Get(r, feat)
for feat := 0; feat < len(docVector); feat++ {
v := docVector[feat]
// In Bernoulli Naive Bayes the presence and absence of
// features are considered. All non-zero values are
// treated as presence.
Expand All @@ -86,7 +100,7 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
// given label.
t, ok := docsContainingTerm[class]
if !ok {
t = make([]int, X.Cols)
t = make([]int, nb.features)
docsContainingTerm[class] = t
}
t[feat] += 1
Expand All @@ -95,20 +109,77 @@ func (nb *BernoulliNBClassifier) Fit(X *base.Instances) {
}

// Pre-calculate conditional probabilities for each class
for c, _ := range classInstances {
nb.logClassPrior[c] = math.Log((float64(classInstances[c]))/float64(X.Rows))
nb.logCondProb[c] = make([]float64, X.Cols)
for feat := 0; feat < X.Cols; feat++ {
for c, _ := range nb.classInstances {
nb.logClassPrior[c] = math.Log((float64(nb.classInstances[c]))/float64(X.Rows))
nb.condProb[c] = make([]float64, nb.features)
for feat := 0; feat < nb.features; feat++ {
classTerms, _ := docsContainingTerm[c]
numDocs := classTerms[feat]
docsInClass, _ := classInstances[c]
docsInClass, _ := nb.classInstances[c]

classLogCondProb, _ := nb.logCondProb[c]
classCondProb, _ := nb.condProb[c]
// Calculate conditional probability with laplace smoothing
classLogCondProb[feat] = math.Log(float64(numDocs + 1) / float64(docsInClass + 1))
classCondProb[feat] = float64(numDocs + 1) / float64(docsInClass + 1)
}
}
}

// PredictOne scores every known class for the given document vector and
// returns the label of the best-scoring class. Each class is scored as:
//
// classScore = log(p(c)) + \sum_{f}{log(p(f|c))}
//
// Only the predicted class string is returned.
//
// IMPORTANT: PredictOne panics if Fit was not called or if the
// document vector and train matrix have a different number of columns.
func (nb *BernoulliNBClassifier) PredictOne(vector []float64) string {
	if nb.features == 0 {
		panic("Fit should be called before predicting")
	}

	if len(vector) != nb.features {
		panic("Different dimensions in Train and Test sets")
	}

	predicted := ""
	maxScore := -math.MaxFloat64

	for label, logPrior := range nb.logClassPrior {
		score := logPrior
		for i, value := range vector {
			p := nb.condProb[label][i]
			switch {
			case value > 0:
				// Feature i is present in the test document.
				score += math.Log(p)
			case p == 1.0:
				// Special case when prob = 1.0: fall back to the
				// laplace-smoothed absence probability to avoid log(0).
				score += math.Log(1.0 / float64(nb.classInstances[label]+1))
			default:
				// Feature i is absent: use the complement probability.
				score += math.Log(1.0 - p)
			}
		}

		if score > maxScore {
			maxScore = score
			predicted = label
		}
	}

	return predicted
}

// Predict applies PredictOne to every row of the given instances and
// returns the resulting prediction vector.
//
// IMPORTANT: Predict panics if Fit was not called or if the
// document vector and train matrix have a different number of columns.
func (nb *BernoulliNBClassifier) Predict(what *base.Instances) *base.Instances {
	ret := what.GeneratePredictionVector()
	for row := 0; row < what.Rows; row++ {
		class := nb.PredictOne(what.GetRowVectorWithoutClass(row))
		ret.SetAttrStr(row, 0, class)
	}
	return ret
}
91 changes: 71 additions & 20 deletions naive/bernoulli_nb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,18 @@ import (
. "github.com/smartystreets/goconvey/convey"
)

func TestFit(t *testing.T) {
// TestNoFit checks that predicting with an untrained classifier panics.
func TestNoFit(t *testing.T) {
	Convey("Given an empty BernoulliNaiveBayes", t, func() {
		classifier := NewBernoulliNBClassifier()

		Convey("PredictOne should panic if Fit was not called", func() {
			doc := []float64{0.0, 1.0}
			So(func() { classifier.PredictOne(doc) }, ShouldPanic)
		})
	})
}

func TestSimple(t *testing.T) {
Convey("Given a simple training data", t, func() {
trainingData, err1 := base.ParseCSVToInstances("test/simple_train.csv", false)
if err1 != nil {
Expand All @@ -17,32 +28,72 @@ func TestFit(t *testing.T) {
nb := NewBernoulliNBClassifier()
nb.Fit(trainingData)

Convey("All log(prior) should be correctly calculated", func() {
logPriorBlue := nb.logClassPrior["blue"]
logPriorRed := nb.logClassPrior["red"]
Convey("Check if Fit is working as expected", func() {
Convey("All log(prior) should be correctly calculated", func() {
logPriorBlue := nb.logClassPrior["blue"]
logPriorRed := nb.logClassPrior["red"]

So(logPriorBlue, ShouldAlmostEqual, math.Log(0.5))
So(logPriorRed, ShouldAlmostEqual, math.Log(0.5))
})

So(logPriorBlue, ShouldAlmostEqual, math.Log(0.5))
So(logPriorRed, ShouldAlmostEqual, math.Log(0.5))
Convey("'red' conditional probabilities should be correct", func() {
logCondProbTok0 := nb.condProb["red"][0]
logCondProbTok1 := nb.condProb["red"][1]
logCondProbTok2 := nb.condProb["red"][2]

So(logCondProbTok0, ShouldAlmostEqual, 1.0)
So(logCondProbTok1, ShouldAlmostEqual, 1.0/3.0)
So(logCondProbTok2, ShouldAlmostEqual, 1.0)
})

Convey("'blue' conditional probabilities should be correct", func() {
logCondProbTok0 := nb.condProb["blue"][0]
logCondProbTok1 := nb.condProb["blue"][1]
logCondProbTok2 := nb.condProb["blue"][2]

So(logCondProbTok0, ShouldAlmostEqual, 1.0)
So(logCondProbTok1, ShouldAlmostEqual, 1.0)
So(logCondProbTok2, ShouldAlmostEqual, 1.0/3.0)
})
})

Convey("'red' conditional probabilities should be correct", func() {
logCondProbTok0 := nb.logCondProb["red"][0]
logCondProbTok1 := nb.logCondProb["red"][1]
logCondProbTok2 := nb.logCondProb["red"][2]
Convey("PredictOne should work as expected", func() {
Convey("Using a document with different number of cols should panic", func() {
testDoc := []float64{0.0, 2.0}
So(func() { nb.PredictOne(testDoc) }, ShouldPanic)
})

Convey("Token 1 should be a good predictor of the blue class", func() {
testDoc := []float64{0.0, 123.0, 0.0}
So(nb.PredictOne(testDoc), ShouldEqual, "blue")

testDoc = []float64{120.0, 123.0, 0.0}
So(nb.PredictOne(testDoc), ShouldEqual, "blue")
})

Convey("Token 2 should be a good predictor of the red class", func() {
testDoc := []float64{0.0, 0.0, 120.0}
So(nb.PredictOne(testDoc), ShouldEqual, "red")

So(logCondProbTok0, ShouldAlmostEqual, math.Log(1.0))
So(logCondProbTok1, ShouldAlmostEqual, math.Log(1.0/3.0))
So(logCondProbTok2, ShouldAlmostEqual, math.Log(1.0))
testDoc = []float64{10.0, 0.0, 120.0}
So(nb.PredictOne(testDoc), ShouldEqual, "red")
})
})

Convey("'blue' conditional probabilities should be correct", func() {
logCondProbTok0 := nb.logCondProb["blue"][0]
logCondProbTok1 := nb.logCondProb["blue"][1]
logCondProbTok2 := nb.logCondProb["blue"][2]
Convey("Predict should work as expected", func() {
testData, err := base.ParseCSVToInstances("test/simple_test.csv", false)
if err != nil {
t.Error(err)
}
predictions := nb.Predict(testData)

So(logCondProbTok0, ShouldAlmostEqual, math.Log(1.0))
So(logCondProbTok1, ShouldAlmostEqual, math.Log(1.0))
So(logCondProbTok2, ShouldAlmostEqual, math.Log(1.0/3.0))
Convey("All simple predicitions should be correct", func() {
So(predictions.GetClass(0), ShouldEqual, "blue")
So(predictions.GetClass(1), ShouldEqual, "red")
So(predictions.GetClass(2), ShouldEqual, "blue")
So(predictions.GetClass(3), ShouldEqual, "red")
})
})
})
}
4 changes: 4 additions & 0 deletions naive/test/simple_test.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
0,12,0,blue
0,0,645,red
9,213,0,blue
21,0,987,red

0 comments on commit 90458d9

Please sign in to comment.