forked from RobotLocomotion/drake
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrandom_source.cc
181 lines (154 loc) · 6.17 KB
/
random_source.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
#include "drake/systems/primitives/random_source.h"
#include <atomic>
#include <random>
#include <variant>
#include "drake/common/default_scalars.h"
#include "drake/common/never_destroyed.h"
namespace drake {
namespace systems {
namespace {
// Alias for RandomSource's seed type, taken from the double instantiation; the
// out-of-line member definitions below use this alias for every scalar T.
using Seed = RandomSource<double>::Seed;
// Stores exactly one of the three supported distribution objects. Note that
// the distribution objects hold computational state; they are not just pure
// mathematical functions.
using DistributionVariant = std::variant<std::uniform_real_distribution<double>,
std::normal_distribution<double>,
std::exponential_distribution<double>>;
// Maps a RandomDistribution enumerator to a default-parameterized standard
// distribution object: U(0, 1) for kUniform, N(0, 1) for kGaussian, and
// Exp(lambda = 1) for kExponential.
DistributionVariant MakeDistributionVariant(RandomDistribution which) {
  using Uniform = std::uniform_real_distribution<double>;
  using Gaussian = std::normal_distribution<double>;
  using Exponential = std::exponential_distribution<double>;
  switch (which) {
    case RandomDistribution::kUniform: {
      return Uniform();
    }
    case RandomDistribution::kGaussian: {
      return Gaussian();
    }
    case RandomDistribution::kExponential: {
      return Exponential();
    }
  }
  // Only reachable if `which` holds a value outside the enumeration.
  DRAKE_UNREACHABLE();
}
// Generates real-valued (i.e., `double`) samples from some distribution. This
// serves as the abstract state of a RandomSource, which encompasses all of the
// source's state *except* for the currently-sampled output values which are
// stored as discrete state.
class SampleGenerator {
 public:
  DRAKE_DEFAULT_COPY_AND_MOVE_AND_ASSIGN(SampleGenerator);

  // A default-constructed generator reports RandomGenerator::default_seed from
  // seed() and holds the variant's first alternative (a uniform distribution).
  SampleGenerator() = default;

  // Seeds a fresh RandomGenerator with `seed` and pairs it with a
  // default-parameterized distribution of kind `which`.
  SampleGenerator(Seed seed, RandomDistribution which)
      : seed_(seed),
        generator_(seed),
        distribution_(MakeDistributionVariant(which)) {}

  // Returns the seed this generator was constructed with.
  Seed seed() const { return seed_; }

  // Draws the next sample. Both the generator and the active distribution
  // object are mutated by the draw.
  double GenerateNext() {
    // std::visit dispatches on whichever alternative is active and fails to
    // compile if the lambda cannot handle one of them. A hand-written switch
    // on index() would instead silently fall through to "unreachable" if a
    // new alternative were ever added to DistributionVariant.
    return std::visit(
        [this](auto& distribution) -> double {
          return distribution(generator_);
        },
        distribution_);
  }

 private:
  Seed seed_{RandomGenerator::default_seed};
  RandomGenerator generator_;
  DistributionVariant distribution_;
};
// Returns the current value of a process-wide atomic counter and then
// increments it by one. The counter starts at RandomGenerator::default_seed
// and is never destroyed, so it remains usable during static destruction.
Seed get_next_seed() {
  static never_destroyed<std::atomic<Seed>> counter(
      RandomGenerator::default_seed);
  std::atomic<Seed>& next_seed = counter.access();
  return next_seed.fetch_add(1);
}
} // namespace
// Declares, in order: a discrete state group of `num_outputs` samples, an
// abstract SampleGenerator state, a periodic unrestricted update (period
// `sampling_interval_sec`, offset 0) that resamples the outputs, and an
// output port named "output" that exposes the sample vector.
template <typename T>
RandomSource<T>::RandomSource(RandomDistribution distribution, int num_outputs,
                              double sampling_interval_sec)
    : LeafSystem<T>(SystemTypeTag<RandomSource>()),
      distribution_(distribution),
      sampling_interval_sec_{sampling_interval_sec},
      instance_seed_{get_next_seed()} {
  const auto samples_index = this->DeclareDiscreteState(num_outputs);
  this->DeclareAbstractState(Value<SampleGenerator>());
  this->DeclarePeriodicUnrestrictedUpdateEvent(
      sampling_interval_sec, 0.0, &RandomSource<T>::UpdateSamples);
  this->DeclareStateOutputPort("output", samples_index);
}
// Defined out of line in this translation unit.
// NOTE(review): presumably so the header need not see SampleGenerator; confirm.
template <typename T>
RandomSource<T>::~RandomSource() = default;
// Scalar-converting copy constructor. Delegates to the main constructor to
// declare identical state, events, and ports, then carries over `other`'s
// seeding configuration. This makes the converted system reseed exactly as
// `other` would in SetDefaultState() and SetRandomState().
template <typename T>
template <typename U>
RandomSource<T>::RandomSource(const RandomSource<U>& other)
    : RandomSource<T>(other.get_distribution(), other.get_output_port(0).size(),
                      other.sampling_interval_sec_) {
  // The delegated constructor draws a brand-new instance_seed_ from
  // get_next_seed() and does not consult other.fixed_seed_. Without these
  // assignments, scalar conversion (e.g. to AutoDiffXd) would silently change
  // the system's random stream.
  fixed_seed_ = other.fixed_seed_;
  instance_seed_ = other.instance_seed_;
}
// Returns the seed stored in the SampleGenerator held as abstract state 0 of
// `context`, after validating that `context` belongs to this system.
template <typename T>
Seed RandomSource<T>::get_seed(const Context<double>& context) const {
  this->ValidateContext(context);
  return context.template get_abstract_state<SampleGenerator>(0).seed();
}
// Reseeds from fixed_seed_ when one is set, otherwise from this instance's
// seed, and then resamples the outputs.
template <typename T>
void RandomSource<T>::SetDefaultState(const Context<T>& context,
                                      State<T>* state) const {
  SetSeed(fixed_seed_.has_value() ? *fixed_seed_ : instance_seed_, context,
          state);
}
// Draws one value from `seed_generator` unconditionally, even when fixed_seed_
// overrides it. The generator therefore advances the same way whether or not
// a fixed seed is set. Then reseeds and resamples the outputs.
template <typename T>
void RandomSource<T>::SetRandomState(const Context<T>& context, State<T>* state,
                                     RandomGenerator* seed_generator) const {
  RandomGenerator& generator = *seed_generator;
  const Seed drawn_seed = generator();
  SetSeed(fixed_seed_ ? *fixed_seed_ : drawn_seed, context, state);
}
// Replaces the SampleGenerator in abstract state 0 with a new one built from
// `seed` and this system's distribution. Then immediately draws a fresh set
// of output samples from it via UpdateSamples().
template <typename T>
void RandomSource<T>::SetSeed(Seed seed, const Context<T>& context,
                              State<T>* state) const {
  SampleGenerator& generator =
      state->template get_mutable_abstract_state<SampleGenerator>(0);
  generator = SampleGenerator(seed, distribution_);
  UpdateSamples(context, state);
}
// Overwrites every element of discrete state group 0 with a new draw from the
// SampleGenerator held in abstract state 0. The generator is advanced as a
// side effect. The context argument is unused.
template <typename T>
void RandomSource<T>::UpdateSamples(const Context<T>&, State<T>* state) const {
  auto& generator =
      state->template get_mutable_abstract_state<SampleGenerator>(0);
  auto& samples = state->get_mutable_discrete_state(0);
  const int count = samples.size();
  for (int k = 0; k < count; ++k) {
    samples[k] = T(generator.GenerateNext());
  }
}
template <typename T>
int AddRandomInputs(double sampling_interval_sec, DiagramBuilder<T>* builder) {
int count = 0;
// Note: the mutable assignment to const below looks odd, but
// there is (currently) no builder->GetSystems() method.
for (const auto* system : builder->GetMutableSystems()) {
for (int i = 0; i < system->num_input_ports(); i++) {
const systems::InputPort<T>& port = system->get_input_port(i);
// Check for the random label.
if (!port.is_random()) {
continue;
}
if (builder->IsConnectedOrExported(port)) {
continue;
}
const auto* const source = builder->template AddSystem<RandomSource<T>>(
port.get_random_type().value(), port.size(), sampling_interval_sec);
builder->Connect(source->get_output_port(0), port);
++count;
}
}
return count;
}
// Explicitly instantiates AddRandomInputs for each default non-symbolic scalar
// type, since its template definition lives in this .cc file.
// clang-format off
DRAKE_DEFINE_FUNCTION_TEMPLATE_INSTANTIATIONS_ON_DEFAULT_NONSYMBOLIC_SCALARS((
&AddRandomInputs<T>
));
// clang-format on
} // namespace systems
} // namespace drake
// Explicitly instantiates RandomSource<T> for each default non-symbolic scalar
// type, since its member definitions live in this .cc file.
DRAKE_DEFINE_CLASS_TEMPLATE_INSTANTIATIONS_ON_DEFAULT_NONSYMBOLIC_SCALARS(
class ::drake::systems::RandomSource);