forked from mroth/weightedrand
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathweightedrand.go
125 lines (112 loc) · 4.15 KB
/
weightedrand.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
// Package weightedrand contains a performant data structure and algorithm used
// to randomly select an element from some kind of list, where the chances of
// each element being selected are not equal, but are defined by relative
// "weights" (or probabilities). This is called weighted random selection.
//
// Compare this package with (github.com/jmcvetta/randutil).WeightedChoice,
// which is optimized for the single operation case. In contrast, this package
// creates a presorted cache optimized for binary search, allowing for repeated
// selections from the same set to be significantly faster, especially for large
// data sets.
package weightedrand
import (
"errors"
"math/rand"
"sort"
)
// Choice is a generic wrapper that can be used to add weights for any item.
// It pairs an arbitrary value with the relative weight used when a Chooser
// selects among several Choices.
type Choice struct {
	// Item is the value handed back by Pick or PickSource when this Choice
	// is selected.
	Item interface{}
	// Weight is the relative weight of Item. NewChooser sums the weights of
	// all Choices, so a Weight of 0 contributes nothing to that total.
	Weight uint
}
// NewChoice builds a Choice that wraps item and carries the given relative
// weight. It is a convenience constructor; the result is identical to filling
// in the two fields of a Choice by hand.
func NewChoice(item interface{}, weight uint) Choice {
	var c Choice
	c.Item = item
	c.Weight = weight
	return c
}
// A Chooser caches many possible Choices in a structure designed to improve
// performance on repeated calls for weighted random selection.
//
// Create one with NewChooser, which fills in all three fields; Pick and
// PickSource read them to select an item.
type Chooser struct {
	// data holds the choices in the order NewChooser sorted them
	// (ascending Weight).
	data []Choice
	// totals[i] is the running sum of the weights of data[0] through
	// data[i]; it is the slice that Pick and PickSource binary-search.
	totals []int
	// max is the sum of all weights, used as the exclusive bound passed to
	// Intn when drawing a random number.
	max int
}
// NewChooser initializes a new Chooser for picking from the provided choices.
//
// The choices slice is sorted in place by ascending Weight and is retained by
// the returned Chooser. When called as NewChooser(s...), the caller's slice s
// is therefore reordered, and later changes to it are visible to the Chooser.
//
// It returns errWeightOverflow if any single weight, or the running sum of the
// weights, reaches the maximum int value for the current platform, and
// errNoValidChoices if the weights sum to zero (which includes the case of no
// choices being supplied at all).
func NewChooser(choices ...Choice) (*Chooser, error) {
	sort.Slice(choices, func(i, j int) bool {
		return choices[i].Weight < choices[j].Weight
	})

	totals := make([]int, len(choices))
	runningTotal := 0
	for i, c := range choices {
		// Weight is a uint, so a value above maxInt would wrap to a negative
		// int in the conversion below, slip past the running-total check and
		// corrupt the cumulative totals. Reject it while it is still
		// unsigned.
		if c.Weight >= maxInt {
			return nil, errWeightOverflow
		}

		weight := int(c.Weight)
		// Compare against the remaining headroom rather than adding first,
		// so the check itself can never overflow.
		if (maxInt - runningTotal) <= weight {
			return nil, errWeightOverflow
		}
		runningTotal += weight
		totals[i] = runningTotal
	}

	// With a total of zero there is nothing to pick, and rand.Intn would
	// panic when handed a non-positive bound.
	if runningTotal < 1 {
		return nil, errNoValidChoices
	}

	return &Chooser{data: choices, totals: totals, max: runningTotal}, nil
}
const (
	// intSize is the width in bits of int and uint on the target platform.
	// ^uint(0) is a uint with every bit set; shifting it right by 63 leaves
	// 1 when uint is 64 bits wide and 0 when it is 32 bits wide, so the
	// expression evaluates to 32<<1 = 64 or 32<<0 = 32.
	intSize = 32 << (^uint(0) >> 63) // cf. strconv.IntSize
	// maxInt is the largest value an int can hold on the target platform
	// (all bits set except the sign bit).
	maxInt = 1<<(intSize-1) - 1
)
// Possible errors returned by NewChooser, preventing the creation of a Chooser
// with unsafe runtime states.
var (
	// errWeightOverflow is returned when the sum of the provided Choice
	// weights would exceed the maximum integer value for the current
	// platform (e.g. math.MaxInt32 or math.MaxInt64). Without this check the
	// internal running total would overflow, resulting in an imbalanced
	// distribution that generates improper results.
	errWeightOverflow = errors.New("sum of Choice Weights exceeds max int")
	// errNoValidChoices is returned when no Choice supplied to the Chooser
	// has a weight >= 1. In that case there is nothing valid to pick, and
	// Pick would produce a runtime panic.
	errNoValidChoices = errors.New("zero Choices with Weight >= 1")
)
// Pick returns a single weighted random Choice.Item from the Chooser.
//
// Randomness comes from the global math/rand source. A number in the range
// [1, c.max] is drawn, and the item returned is the first one whose cumulative
// weight total is at least that number.
func (c Chooser) Pick() interface{} {
	target := 1 + rand.Intn(c.max)
	return c.data[searchInts(c.totals, target)].Item
}
// PickSource returns a single weighted random Choice.Item from the Chooser,
// drawing its randomness from the provided *rand.Rand rs instead of the global
// source.
//
// The primary use-case for this is to avoid lock contention on the global
// random source when Chooser(s) are used from multiple goroutines in
// extremely high-throughput situations.
//
// It is the responsibility of the caller to ensure the provided rand.Source is
// free from thread safety issues.
func (c Chooser) PickSource(rs *rand.Rand) interface{} {
	target := 1 + rs.Intn(c.max)
	return c.data[searchInts(c.totals, target)].Item
}
// searchInts returns the smallest index at which a holds a value >= x, or
// len(a) if every element is smaller than x. The slice a must be sorted in
// ascending order.
//
// It is a hand-inlined equivalent of the standard library's sort.SearchInts(),
// which merely wraps the generic sort.Search() and its function closure. That
// closure-based form is called in a loop and cannot currently be inlined
// properly by the compiler, which adds non-trivial overhead. In this package's
// use case the manual version yields up to a ~33% overall throughput increase
// for Pick().
func searchInts(a []int, x int) int {
	// Possible further future optimization for searchInts via SIMD if we want
	// to write some Go assembly code: http://0x80.pl/articles/simd-search.html
	lo, hi := 0, len(a)
	for lo < hi {
		// Halve via an unsigned shift so that lo+hi cannot overflow into a
		// negative midpoint.
		mid := int(uint(lo+hi) >> 1)
		if x <= a[mid] {
			hi = mid
		} else {
			lo = mid + 1
		}
	}
	return lo
}