Skip to content

Commit

Permalink
params: added NormalParam, renamed LinearParam to UniformParam, added…
Browse files Browse the repository at this point in the history
… param tests, made tests run in parallel
  • Loading branch information
d4l3k committed Jul 6, 2018
1 parent b1c4ed6 commit 6d3ec3e
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 9 deletions.
4 changes: 3 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
language: go

go:
- 1.8.x
- 1.9.x
- 1.10.x
- master

script: go test -v -cover ./...
6 changes: 6 additions & 0 deletions bayesopt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
)

func TestOptimizer(t *testing.T) {
t.Parallel()

X := LinearParam{
Max: 10,
Min: -10,
Expand Down Expand Up @@ -50,6 +52,8 @@ func TestOptimizer(t *testing.T) {
}

func TestOptimizerMax(t *testing.T) {
t.Parallel()

X := LinearParam{
Max: 10,
Min: -10,
Expand Down Expand Up @@ -95,6 +99,8 @@ func TestOptimizerMax(t *testing.T) {
}

func TestOptimizerBounds(t *testing.T) {
t.Parallel()

X := LinearParam{
Max: 10,
Min: 5,
Expand Down
74 changes: 66 additions & 8 deletions params.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
package bayesopt

import "math/rand"
import (
"math"
"math/rand"
)

// SampleTries is the number of tries a sample function should try before
// truncating the samples to the boundaries.
var SampleTries = 1000

// Param represents a parameter that can be optimized.
type Param interface {
Expand All @@ -15,30 +22,81 @@ type Param interface {
Sample() float64
}

var _ Param = LinearParam{}
var _ Param = UniformParam{}

// LinearParam is a UniformParam. Deprecated.
type LinearParam = UniformParam

// LinearParam is a uniformly distributed parameter between Max and Min.
type LinearParam struct {
// UniformParam is a uniformly distributed parameter between Max and Min.
type UniformParam struct {
Name string
Max, Min float64
}

// GetName implements Param.
func (p LinearParam) GetName() string {
func (p UniformParam) GetName() string {
return p.Name
}

// GetMax implements Param.
func (p LinearParam) GetMax() float64 {
func (p UniformParam) GetMax() float64 {
return p.Max
}

// GetMin implements Param.
func (p LinearParam) GetMin() float64 {
func (p UniformParam) GetMin() float64 {
return p.Min
}

// Sample implements Param.
func (p LinearParam) Sample() float64 {
func (p UniformParam) Sample() float64 {
return rand.Float64()*(p.Max-p.Min) + p.Min
}

var _ Param = NormalParam{}

// NormalParam is a normally distributed parameter with Mean and StdDev between
// Max and Min. The Max and Min parameters use discard sampling to find a point
// between them. Set them to be math.Inf(1) and math.Inf(-1) to disable the
// bounds.
type NormalParam struct {
Name string
Max, Min float64
Mean, StdDev float64
}

// GetName implements Param.
func (p NormalParam) GetName() string {
return p.Name
}

// GetMax implements Param.
func (p NormalParam) GetMax() float64 {
return p.Max
}

// GetMin implements Param.
func (p NormalParam) GetMin() float64 {
return p.Min
}

// Sample implements Param.
func (p NormalParam) Sample() float64 {
return truncateSample(p, func() float64 {
return rand.NormFloat64()*p.StdDev + p.Mean
})
}

func truncateSample(p Param, f func() float64) float64 {
max := p.GetMax()
min := p.GetMin()

var sample float64
for i := 0; i < SampleTries; i++ {
sample = f()
if sample >= min && sample <= max {
return sample
}
}
return math.Min(math.Max(sample, min), max)
}
138 changes: 138 additions & 0 deletions params_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package bayesopt

import "testing"

func TestParams(t *testing.T) {
t.Parallel()

cases := []struct {
p Param
name string
max, min float64
}{
{
p: UniformParam{
Name: "uniform",
Max: 10,
Min: 1,
},
name: "uniform",
max: 10,
min: 1,
},
{
p: NormalParam{
Name: "normal",
Max: 10,
Min: -10,
Mean: 0,
StdDev: 10,
},
name: "normal",
max: 10,
min: -10,
},
{
p: NormalParam{
Name: "normal",
Max: 10,
Min: 0,
Mean: 1,
StdDev: 5,
},
name: "normal",
max: 10,
min: 0,
},
}

for i, c := range cases {
{
out := c.p.GetName()
want := c.name
if out != want {
t.Errorf("%d. %+v.GetName() = %q; wanted %q", i, c.p, out, want)
}
}
{
out := c.p.GetMax()
want := c.max
if out != want {
t.Errorf("%d. %+v.GetMax() = %v; wanted %v", i, c.p, out, want)
}
}
{
out := c.p.GetMin()
want := c.min
if out != want {
t.Errorf("%d. %+v.GetMin() = %v; wanted %v", i, c.p, out, want)
}
}
for j := 0; j < 1000; j++ {
sample := c.p.Sample()
if sample < c.min || sample > c.max {
t.Errorf("%d. %+v.Sample() = %v; outside bounds", i, c.p, sample)
}
}
}
}

func TestTruncateSample(t *testing.T) {
t.Parallel()

var count int

cases := []struct {
p Param
f func() float64
want float64
count int
}{
{
p: UniformParam{"", 10, 5},
f: func() float64 {
return 0
},
want: 5,
count: 1000,
},
{
p: UniformParam{"", 10, 5},
f: func() float64 {
return 100
},
want: 10,
count: 1000,
},
{
p: UniformParam{"", 10, 5},
f: func() float64 {
return 7
},
want: 7,
count: 1,
},
{
p: UniformParam{"", 10, 5},
f: func() float64 {
return float64(count)
},
want: 5,
count: 5,
},
}

for i, c := range cases {
count = 0
out := truncateSample(c.p, func() float64 {
count++
return c.f()
})
if out != c.want {
t.Errorf("%d. truncateSample(%+v, ...) = %f; not %f", i, c.p, out, c.want)
}
if count != c.count {
t.Errorf("%d. count = %d; not %d", i, count, c.count)
}
}
}

0 comments on commit 6d3ec3e

Please sign in to comment.