forked from sjwhitworth/golearn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbernoulli_nb_test.go
142 lines (120 loc) · 4.46 KB
/
bernoulli_nb_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
package naive
import (
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/filters"
. "github.com/smartystreets/goconvey/convey"
"io/ioutil"
"os"
"testing"
)
// TestNoFit checks that a BernoulliNBClassifier that has never had Fit
// called refuses to classify a document, signalling this with a panic.
func TestNoFit(t *testing.T) {
	Convey("Given an empty BernoulliNaiveBayes", t, func() {
		unfitted := NewBernoulliNBClassifier()

		Convey("PredictOne should panic if Fit was not called", func() {
			doc := [][]byte{{0}, {1}}
			predict := func() { unfitted.PredictOne(doc) }
			So(predict, ShouldPanic)
		})
	})
}
// convertToBinary returns a lazily filtered view of src in which a
// BinaryConvertFilter has been registered for every attribute reported by
// base.NonClassAttributes(src) and then trained.
//
// The values returned by AddAttribute and Train are discarded here.
// NOTE(review): if either reports an error, it is currently ignored.
func convertToBinary(src base.FixedDataGrid) base.FixedDataGrid {
	filter := filters.NewBinaryConvertFilter()
	for _, attr := range base.NonClassAttributes(src) {
		filter.AddAttribute(attr)
	}
	filter.Train()
	return base.NewLazilyFilteredInstances(src, filter)
}
// TestSerialize checks that a fitted BernoulliNBClassifier survives a
// Save/Load round trip: a classifier loaded from the saved file must produce
// the same predictions on test/simple_test.csv as the classifier that was
// saved.
func TestSerialize(t *testing.T) {
	Convey("Given simple training/test data", t, func() {
		trainingData, err := base.ParseCSVToInstances("test/simple_train.csv", false)
		So(err, ShouldBeNil)
		testData, err := base.ParseCSVToTemplatedInstances("test/simple_test.csv", false, trainingData)
		So(err, ShouldBeNil)

		nb := NewBernoulliNBClassifier()
		nb.Fit(convertToBinary(trainingData))
		oldPredictions, err := nb.Predict(convertToBinary(testData))
		// Fail here rather than later with a confusing mismatch (or nil
		// dereference) when the reference predictions could not be computed.
		So(err, ShouldBeNil)

		Convey("Saving the classifier should work...", func() {
			f, err := ioutil.TempFile(os.TempDir(), "nb")
			So(err, ShouldBeNil)
			// This deferred cleanup runs once the nested Conveys below have
			// finished with the file. The temporary file is removed as well
			// as closed so repeated test runs do not litter the temp dir.
			defer func() {
				f.Close()
				os.Remove(f.Name())
			}()

			err = nb.Save(f.Name())
			So(err, ShouldBeNil)

			Convey("Loading the classifier should work...", func() {
				newNb := NewBernoulliNBClassifier()
				err := newNb.Load(f.Name())
				So(err, ShouldBeNil)

				Convey("Predictions should match...", func() {
					newPredictions, err := newNb.Predict(convertToBinary(testData))
					So(err, ShouldBeNil)
					So(base.InstancesAreEqual(oldPredictions, newPredictions), ShouldBeTrue)
				})
			})
		})
	})
}
// TestSimple fits a BernoulliNBClassifier on test/simple_train.csv and then
// checks three things:
//   - the fitted class counts and conditional probabilities;
//   - single-document classification via PredictOne;
//   - batch classification of test/simple_test.csv via Predict.
func TestSimple(t *testing.T) {
	Convey("Given a simple training dataset", t, func() {
		train, err := base.ParseCSVToInstances("test/simple_train.csv", false)
		So(err, ShouldBeNil)

		clf := NewBernoulliNBClassifier()
		clf.Fit(convertToBinary(train))

		Convey("Check if Fit is working as expected", func() {
			Convey("All data needed for prior should be correctly calculated", func() {
				So(clf.classInstances["blue"], ShouldEqual, 2)
				So(clf.classInstances["red"], ShouldEqual, 2)
				So(clf.trainingInstances, ShouldEqual, 4)
			})

			Convey("'red' conditional probabilities should be correct", func() {
				// Compared against 1.0 and 1/3 below, so these are plain
				// probabilities rather than log-probabilities.
				red := clf.condProb["red"]
				p0, p1, p2 := red[0], red[1], red[2]
				So(p0, ShouldAlmostEqual, 1.0)
				So(p1, ShouldAlmostEqual, 1.0/3.0)
				So(p2, ShouldAlmostEqual, 1.0)
			})

			Convey("'blue' conditional probabilities should be correct", func() {
				blue := clf.condProb["blue"]
				p0, p1, p2 := blue[0], blue[1], blue[2]
				So(p0, ShouldAlmostEqual, 1.0)
				So(p1, ShouldAlmostEqual, 1.0)
				So(p2, ShouldAlmostEqual, 1.0/3.0)
			})
		})

		Convey("PredictOne should work as expected", func() {
			Convey("Using a document with different number of cols should panic", func() {
				// Two columns, whereas the other documents here use three.
				short := [][]byte{{0}, {2}}
				So(func() { clf.PredictOne(short) }, ShouldPanic)
			})

			Convey("Token 1 should be a good predictor of the blue class", func() {
				So(clf.PredictOne([][]byte{{0}, {1}, {0}}), ShouldEqual, "blue")
				So(clf.PredictOne([][]byte{{1}, {1}, {0}}), ShouldEqual, "blue")
			})

			Convey("Token 2 should be a good predictor of the red class", func() {
				So(clf.PredictOne([][]byte{{0}, {0}, {1}}), ShouldEqual, "red")
				So(clf.PredictOne([][]byte{{1}, {0}, {1}}), ShouldEqual, "red")
			})
		})

		Convey("Predict should work as expected", func() {
			testSet, err := base.ParseCSVToTemplatedInstances("test/simple_test.csv", false, train)
			So(err, ShouldBeNil)
			predictions, err := clf.Predict(convertToBinary(testSet))
			So(err, ShouldBeNil)

			Convey("All simple predictions should be correct", func() {
				// Expected class per row of the test set, in row order.
				for row, want := range []string{"blue", "red", "blue", "red"} {
					So(base.GetClass(predictions, row), ShouldEqual, want)
				}
			})
		})
	})
}