Skip to content

Commit

Permalink
add concurrentMap test and clean up documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
piazzamp committed Jul 23, 2016
1 parent 1c495e2 commit 0b98886
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 10 deletions.
20 changes: 10 additions & 10 deletions text/bayes.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ type NaiveBayes struct {
// Words holds a map of words
// to their corresponding Word
// structure
Words histogram `json:"words"`
Words concurrentMap `json:"words"`

// Count holds the number of times
// class i was seen as Count[i]
Expand Down Expand Up @@ -161,25 +161,25 @@ type NaiveBayes struct {
Output io.Writer
}

// histogram allows conncurrency-friendly map access via its
// exported Get and Set methods
type histogram struct {
// concurrentMap allows concurrency-friendly map
// access via its exported Get and Set methods.
//
// Callers should go through Get and Set rather than
// indexing the words map directly: Get takes the read
// lock and Set takes the write lock around the access.
type concurrentMap struct {
// The mutex is embedded, so Lock/Unlock and
// RLock/RUnlock are called directly on the
// concurrentMap value.
sync.RWMutex
// words maps a word to its Word structure. It must
// be non-nil before Set is called, since Set writes
// into it.
words map[string]Word
}

// Get looks up a word from h's Word map, it should be used in
// place of a direct map lookup
// the only caveat here is that it will always return the 'success' boolean
func (h *histogram) Get(w string) (Word, bool) {
// Get returns the Word stored under w in h's Word map,
// holding the read lock for the duration of the lookup.
// Use it instead of indexing the map directly. Unlike a
// bare map index expression, it always returns the second
// 'success' boolean, which reports whether w was present.
func (h *concurrentMap) Get(w string) (Word, bool) {
	h.RLock()
	defer h.RUnlock()

	word, found := h.words[w]
	return word, found
}

// Set sets word k's value to v in h's Word map
func (h *histogram) Set(k string, v Word) {
func (h *concurrentMap) Set(k string, v Word) {
h.Lock()
h.words[k] = v
h.Unlock()
Expand Down Expand Up @@ -217,7 +217,7 @@ type Word struct {
// comply with the transform.RemoveFunc interface
func NewNaiveBayes(stream <-chan base.TextDatapoint, classes uint8, sanitize func(rune) bool) *NaiveBayes {
return &NaiveBayes{
Words: histogram{sync.RWMutex{}, make(map[string]Word)},
Words: concurrentMap{sync.RWMutex{}, make(map[string]Word)},
Count: make([]uint64, classes),
Probabilities: make([]float64, classes),

Expand Down
52 changes: 52 additions & 0 deletions text/bayes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"fmt"
"io/ioutil"
"os"
"strings"
"sync"
"testing"

"github.com/cdipaolo/goml/base"
Expand Down Expand Up @@ -228,3 +230,53 @@ func TestPersistPerceptronShouldPass1(t *testing.T) {
class = model.Predict("My mother is in Los Angeles") // 0
assert.EqualValues(t, 1, class, "Class should be 0")
}

// TestConcurrentPredictionAndLearningShouldNotFail makes sure that calling
// Predict while the model is still training does not cause a runtime panic
// because of concurrent map reads & writes.
//
// Every goroutine the test starts is tracked by the WaitGroup, including the
// one that drains the error channel, so nothing reports on t after the test
// function has returned.
func TestConcurrentPredictionAndLearningShouldNotFail(t *testing.T) {
	c := make(chan base.TextDatapoint, 100)
	model := NewNaiveBayes(c, 2, base.OnlyWords)
	errs := make(chan error)

	// Queue the training data up front. 99 datapoints fit inside the
	// 100-slot buffer, so none of these sends block.
	for i := uint8(0); i < 99; i++ {
		c <- base.TextDatapoint{
			X: strings.Repeat("a whole bunch of words that will take some time to iterate through", 50),
			Y: i % 2,
		}
	}

	var wg sync.WaitGroup

	// Drain errors reported by OnlineLearn and fail the test on any of them.
	// NOTE(review): waiting on this goroutine relies on OnlineLearn closing
	// the error channel once the stream is exhausted. The original receive
	// loop already depended on that close to terminate — confirm.
	wg.Add(1)
	go func() {
		defer wg.Done()
		for err := range errs {
			if err != nil {
				t.Errorf("Error passed: %s\n", err.Error())
			}
		}
	}()

	// Spin off a "long" running loop of predicting...
	wg.Add(1)
	go func() {
		defer wg.Done()
		for n := 0; n < 500; n++ {
			model.Predict(strings.Repeat("some stuff that might be in the training data like iterate", 25))
		}
	}()

	// ...and train on the queued datapoints at the same time.
	wg.Add(1)
	go func() {
		defer wg.Done()
		model.OnlineLearn(errs)
	}()

	// Closing the stream lets the learner stop once it has consumed
	// everything that was queued above.
	close(c)
	wg.Wait()
}

0 comments on commit 0b98886

Please sign in to comment.