diff --git a/params.go b/params.go index a4ca7f6..d5ce6ec 100644 --- a/params.go +++ b/params.go @@ -134,3 +134,25 @@ func truncateSample(p Param, f func() float64) float64 { } return math.Min(math.Max(sample, min), max) } + +// RejectionParam samples from Param and then uses F to decide whether or not to +// reject the sample. This is typically used with a UniformParam. F should +// output a value between 0 and 1 indicating the proportion of samples this +// point should be accepted. If F always outputs 0, Sample will get stuck in an +// infinite loop. +type RejectionParam struct { + Param + + F func(x float64) float64 +} + +// Sample implements Param. +func (p RejectionParam) Sample() float64 { + for { + x := p.Param.Sample() + y := p.F(x) + if rand.Float64() < y { + return x + } + } +} diff --git a/params_test.go b/params_test.go index 51f66f4..f3d4349 100644 --- a/params_test.go +++ b/params_test.go @@ -6,9 +6,10 @@ func TestParams(t *testing.T) { t.Parallel() cases := []struct { - p Param - name string - max, min float64 + p Param + name string + max, min float64 + sampleMax float64 }{ { p: UniformParam{ @@ -55,6 +56,25 @@ func TestParams(t *testing.T) { max: 100, min: 10, }, + { + p: RejectionParam{ + Param: UniformParam{ + Name: "rejection uniform", + Max: 100, + Min: 10, + }, + F: func(x float64) float64 { + if x > 50 { + return 0 + } + return 1 + }, + }, + name: "rejection uniform", + max: 100, + min: 10, + sampleMax: 50, + }, } for i, c := range cases { @@ -84,6 +104,9 @@ func TestParams(t *testing.T) { if sample < c.min || sample > c.max { t.Errorf("%d. %+v.Sample() = %v; outside bounds", i, c.p, sample) } + if c.sampleMax != 0 && sample > c.sampleMax { + t.Errorf("%d. %+v.Sample() = %v; outside sample max %f", i, c.p, sample, c.sampleMax) + } } } }