Skip to content

Commit

Permalink
RejectionParam: adds param that allows for rejection sampling for mor…
Browse files Browse the repository at this point in the history
…e complex distributions
  • Loading branch information
d4l3k committed Aug 23, 2018
1 parent 55004c7 commit 40afc5e
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 3 deletions.
22 changes: 22 additions & 0 deletions params.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,25 @@ func truncateSample(p Param, f func() float64) float64 {
}
return math.Min(math.Max(sample, min), max)
}

// RejectionParam samples from Param and then uses F to decide whether or not to
// reject the sample. This is typically used with a UniformParam. F should
// output a value between 0 and 1 indicating the proportion of samples this
// point should be accepted. If F always outputs 0, Sample will get stuck in an
// infinite loop.
type RejectionParam struct {
Param

F func(x float64) float64
}

// Sample implements Param.
func (p RejectionParam) Sample() float64 {
for {
x := p.Param.Sample()
y := p.F(x)
if rand.Float64() < y {
return x
}
}
}
29 changes: 26 additions & 3 deletions params_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ func TestParams(t *testing.T) {
t.Parallel()

cases := []struct {
p Param
name string
max, min float64
p Param
name string
max, min float64
sampleMax float64
}{
{
p: UniformParam{
Expand Down Expand Up @@ -55,6 +56,25 @@ func TestParams(t *testing.T) {
max: 100,
min: 10,
},
{
p: RejectionParam{
Param: UniformParam{
Name: "rejection uniform",
Max: 100,
Min: 10,
},
F: func(x float64) float64 {
if x > 50 {
return 0
}
return 1
},
},
name: "rejection uniform",
max: 100,
min: 10,
sampleMax: 50,
},
}

for i, c := range cases {
Expand Down Expand Up @@ -84,6 +104,9 @@ func TestParams(t *testing.T) {
if sample < c.min || sample > c.max {
t.Errorf("%d. %+v.Sample() = %v; outside bounds", i, c.p, sample)
}
if c.sampleMax != 0 && sample > c.sampleMax {
t.Errorf("%d. %+v.Sample() = %v; outside sample max %f", i, c.p, sample, c.sampleMax)
}
}
}
}
Expand Down

0 comments on commit 40afc5e

Please sign in to comment.