Skip to content

Commit

Permalink
gp: tweaks + tests + plot
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Sep 20, 2017
1 parent 734ebfc commit a66f2ba
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 40 deletions.
40 changes: 35 additions & 5 deletions gp/gp.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
package gp

import (
"fmt"
"math"

"github.com/pkg/errors"
"gonum.org/v1/gonum/mat"
"gonum.org/v1/gonum/stat"
Expand All @@ -13,8 +16,12 @@ import (
type GP struct {
inputs [][]float64
outputs []float64
cov Cov
noise float64

inputNames []string
outputName string

cov Cov
noise float64

alpha *mat.VecDense
l *mat.Cholesky
Expand All @@ -31,6 +38,25 @@ func New(cov Cov, noise float64) *GP {
}
}

func (gp *GP) SetNames(inputs []string, output string) {
gp.inputNames = inputs
gp.outputName = output
}

func (gp GP) Name(i int) string {
if len(gp.inputNames) > i {
return gp.inputNames[i]
}
return fmt.Sprintf("x[%d]", i)
}

func (gp GP) OutputName() string {
if len(gp.outputName) > 0 {
return gp.outputName
}
return "y"
}

func (gp GP) RawData() ([][]float64, []float64) {
inputs := make([][]float64, len(gp.inputs))
for i, s := range gp.inputs {
Expand Down Expand Up @@ -92,23 +118,27 @@ func (gp *GP) normOutputs() []float64 {
return out
}

// Estimate returns the mean and variance at the point x.
// Estimate returns the mean and standard deviation at the point x.
func (gp *GP) Estimate(x []float64) (float64, float64, error) {
if gp.dirty {
if err := gp.compute(); err != nil {
return 0, 0, err
}
}
n := len(gp.inputs)

kstar := mat.NewVecDense(n, nil)
for i := 0; i < n; i++ {
kstar.SetVec(i, gp.cov(gp.inputs[i], x))
}
mean := mat.Dot(kstar, gp.alpha)*gp.stddev + gp.mean

v := mat.NewVecDense(n, nil)
if err := gp.l.SolveVec(v, kstar); err != nil {
return 0, 0, err
}
variance := gp.cov(x, x) - mat.Dot(v, v)
return mean, variance, nil
variance := gp.cov(x, x) - mat.Dot(kstar, v)
sd := math.Sqrt(variance)

return mean, sd, nil
}
27 changes: 19 additions & 8 deletions gp/gp_test.go
Original file line number Diff line number Diff line change
@@ -1,29 +1,40 @@
package gp_test

import (
"math"
"math/rand"
"testing"

"github.com/d4l3k/go-bayesopt/gp"
"github.com/d4l3k/go-bayesopt/gp/plot"
"github.com/gonum/floats"
)

func f(x, y float64) float64 {
return math.Cos(x/2)/2 + math.Sin(y/4)
}

func gpAdd(gp *gp.GP, x, y float64) {
gp.Add([]float64{x, y}, f(x, y))
}

func TestKnown(t *testing.T) {
gp := gp.New(gp.MaternCov, 0)
gp.Add([]float64{1}, 1)
gp.Add([]float64{2}, 2)
gp.Add([]float64{3}, 3)
gp.Add([]float64{4}, 4)
gp.Add([]float64{5}, 5)
gp.Add([]float64{10}, 10)

gpAdd(gp, 0.25, 0.75)

for i := 0; i < 20; i++ {
gpAdd(gp, rand.Float64()*2*math.Pi-math.Pi, rand.Float64()*2*math.Pi-math.Pi)
}

if _, err := plot.SaveAll(gp); err != nil {
t.Fatal(err)
}
mean, variance, err := gp.Estimate([]float64{1})
mean, variance, err := gp.Estimate([]float64{0.25, 0.75})
if err != nil {
t.Fatal(err)
}
if !floats.EqualWithinAbs(mean, 1, 0.0001) {
if !floats.EqualWithinAbs(mean, f(0.25, 0.75), 0.0001) {
t.Fatalf("got mean = %f; not 1", mean)
}
if !floats.EqualWithinAbs(variance, 0, 0.0001) {
Expand Down
44 changes: 17 additions & 27 deletions gp/plot/plot.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"fmt"
"io"
"io/ioutil"
"log"
"math"
"os"
"path"
"sort"
Expand All @@ -26,7 +24,7 @@ func SaveAll(gp *gp.GP) (string, error) {
}
dims := gp.Dims()
for i := 0; i < dims; i++ {
name := fmt.Sprintf("%d.svg", i)
name := fmt.Sprintf("%d.png", i)
fpath := path.Join(dir, name)
f, err := os.OpenFile(fpath, os.O_CREATE|os.O_WRONLY, 0755)
if err != nil {
Expand All @@ -37,7 +35,6 @@ func SaveAll(gp *gp.GP) (string, error) {
return "", err
}
f.Close()
log.Printf("%d: %s", i, fpath)
}
return dir, nil
}
Expand Down Expand Up @@ -73,24 +70,25 @@ func GP(gp *gp.GP, w io.Writer, dim int) error {
knownY[i] = p.y
}

const padding = 20

graph := chart.Chart{
Title: fmt.Sprintf("Gaussian Process: Dimension %d/%d", dim, dims),
Title: fmt.Sprintf("%s vs. %s", gp.Name(dim), gp.OutputName()),
TitleStyle: chart.StyleShow(),
XAxis: chart.XAxis{
Style: chart.Style{
Show: true,
},
Name: gp.Name(dim),
NameStyle: chart.StyleShow(),
Style: chart.StyleShow(),
},
YAxis: chart.YAxis{
Style: chart.Style{
Show: true,
},
Name: gp.OutputName(),
NameStyle: chart.StyleShow(),
Style: chart.StyleShow(),
},
Background: chart.Style{
Padding: chart.Box{
Top: padding,
Left: padding,
Top: 20,
Left: 20,
Bottom: 20,
Right: 20,
},
},
}
Expand Down Expand Up @@ -124,29 +122,21 @@ outer:
}
lowerPair = pairs[pairI]
upperPair = pairs[pairI+1]
log.Printf("j %d, xi %f, %+v %+v", j, xi, upperPair, lowerPair)
}

mid := (xi - lowerPair.x[dim]) / (upperPair.x[dim] - lowerPair.x[dim])
args := make([]float64, dims)
floats.AddScaled(args, 1-mid, lowerPair.x)
floats.AddScaled(args, mid, upperPair.x)
mean, variance, err := gp.Estimate(args)
mean, sd, err := gp.Estimate(args)
if err != nil {
return err
}
means[j] = mean
sd := math.Sqrt(math.Abs(variance))
log.Printf("sd %f, var %f, mean %f", sd, variance, mean)
if variance < 0 {
sd = -sd
}
uppers[j] = mean + sd
lowers[j] = mean - sd
}

log.Println(x, means, uppers, lowers)

graph.Series = append(
graph.Series,
chart.ContinuousSeries{
Expand All @@ -155,12 +145,12 @@ outer:
YValues: means,
},
chart.ContinuousSeries{
Name: "+σ = 1",
Name: "+",
XValues: x,
YValues: uppers,
},
chart.ContinuousSeries{
Name: "-σ = 1",
Name: "-",
XValues: x,
YValues: lowers,
},
Expand All @@ -180,7 +170,7 @@ outer:
},
)

if err := graph.Render(chart.SVG, w); err != nil {
if err := graph.Render(chart.PNG, w); err != nil {
return err
}
return nil
Expand Down

0 comments on commit a66f2ba

Please sign in to comment.