forked from ml-explore/mlx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrandom.cpp
306 lines (275 loc) · 10 KB
/
random.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
// Copyright © 2023 Apple Inc.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "python/src/utils.h"
#include "mlx/ops.h"
#include "mlx/random.h"
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
using namespace mlx::core::random;
// Registers the `mlx.core.random` Python submodule on `parent_module` and
// binds the PRNG entry points from mlx::core::random: seed, key, split,
// uniform, normal, randint, bernoulli, truncated_normal, gumbel and
// categorical.
//
// NOTE: pybind11 associates each keyword annotation ("name"_a) with the bound
// callable's parameters strictly by position. The order of the "..."_a
// annotations must therefore match the order of the lambda/function
// parameters exactly, or Python keyword arguments are routed to the wrong
// C++ parameter.
void init_random(py::module_& parent_module) {
  auto m = parent_module.def_submodule(
      "random",
      "mlx.core.random: functionality related to random number generation");

  m.def(
      "seed",
      &seed,
      "seed"_a,
      R"pbdoc(
        Seed the global PRNG.

        Args:
            seed (int): Seed for the global PRNG.
      )pbdoc");

  m.def(
      "key",
      &key,
      "seed"_a,
      R"pbdoc(
        Get a PRNG key from a seed.

        Args:
            seed (int): Seed for the PRNG.

        Returns:
            array: The PRNG key array.
      )pbdoc");

  m.def(
      "split",
      py::overload_cast<const array&, int, StreamOrDevice>(&random::split),
      "key"_a,
      "num"_a = 2,
      "stream"_a = none,
      R"pbdoc(
        Split a PRNG key into sub keys.

        Args:
            key (array): Input key to split.
            num (int, optional): Number of sub keys. Default is 2.

        Returns:
            array: The array of sub keys with ``num`` as its first dimension.
      )pbdoc");

  m.def(
      "uniform",
      [](const ScalarOrArray& low,
         const ScalarOrArray& high,
         const std::vector<int>& shape,
         std::optional<Dtype> type,
         const std::optional<array>& key,
         StreamOrDevice s) {
        // An explicit `dtype=None` from Python falls back to float32.
        return uniform(
            to_array(low),
            to_array(high),
            shape,
            type.value_or(float32),
            key,
            s);
      },
      "low"_a = 0,
      "high"_a = 1,
      "shape"_a = std::vector<int>{},
      "dtype"_a = std::optional{float32},
      "key"_a = none,
      "stream"_a = none,
      R"pbdoc(
        Generate uniformly distributed random numbers.

        The values are sampled uniformly in the half-open interval ``[low, high)``.
        The lower and upper bound can be scalars or arrays and must be
        broadcastable to ``shape``.

        Args:
            low (scalar or array, optional): Lower bound of the distribution. Default is ``0``.
            high (scalar or array, optional): Upper bound of the distribution. Default is ``1``.
            shape (list(int), optional): Shape of the output. Default is ``()``.
            dtype (Dtype, optional): Type of the output. Default is ``float32``.
            key (array, optional): A PRNG key. Default: None.

        Returns:
            array: The output array of random values.
      )pbdoc");

  m.def(
      "normal",
      [](const std::vector<int>& shape,
         std::optional<Dtype> type,
         const std::optional<array>& key,
         StreamOrDevice s) {
        return normal(shape, type.value_or(float32), key, s);
      },
      "shape"_a = std::vector<int>{},
      "dtype"_a = std::optional{float32},
      "key"_a = none,
      "stream"_a = none,
      R"pbdoc(
        Generate normally distributed random numbers.

        Args:
            shape (list(int), optional): Shape of the output. Default is ``()``.
            dtype (Dtype, optional): Type of the output. Default is ``float32``.
            key (array, optional): A PRNG key. Default: None.

        Returns:
            array: The output array of random values.
      )pbdoc");

  m.def(
      "randint",
      [](const ScalarOrArray& low,
         const ScalarOrArray& high,
         const std::vector<int>& shape,
         std::optional<Dtype> type,
         const std::optional<array>& key,
         StreamOrDevice s) {
        // An explicit `dtype=None` from Python falls back to int32.
        return randint(
            to_array(low), to_array(high), shape, type.value_or(int32), key, s);
      },
      "low"_a,
      "high"_a,
      "shape"_a = std::vector<int>{},
      "dtype"_a = std::optional{int32},
      "key"_a = none,
      "stream"_a = none,
      R"pbdoc(
        Generate random integers from the given interval.

        The values are sampled with equal probability from the integers in
        half-open interval ``[low, high)``. The lower and upper bound can be
        scalars or arrays and must be broadcastable to ``shape``.

        Args:
            low (scalar or array): Lower bound of the interval.
            high (scalar or array): Upper bound of the interval.
            shape (list(int), optional): Shape of the output. Defaults to ``()``.
            dtype (Dtype, optional): Type of the output. Defaults to ``int32``.
            key (array, optional): A PRNG key. Default: None.

        Returns:
            array: The array of random integers.
      )pbdoc");

  m.def(
      "bernoulli",
      [](const ScalarOrArray& p_,
         const std::optional<std::vector<int>>& shape,
         const std::optional<array>& key,
         StreamOrDevice s) {
        auto p = to_array(p_);
        // Without an explicit shape, defer to the overload that derives the
        // output shape from `p`.
        if (shape.has_value()) {
          return bernoulli(p, shape.value(), key, s);
        }
        return bernoulli(p, key, s);
      },
      "p"_a = 0.5,
      "shape"_a = none,
      "key"_a = none,
      "stream"_a = none,
      R"pbdoc(
        Generate Bernoulli random values.

        The values are sampled from the bernoulli distribution with parameter
        ``p``. The parameter ``p`` can be a :obj:`float` or :obj:`array` and
        must be broadcastable to ``shape``.

        Args:
            p (float or array, optional): Parameter of the Bernoulli
                distribution. Default is 0.5.
            shape (list(int), optional): Shape of the output. The default
                shape is ``p.shape``.
            key (array, optional): A PRNG key. Default: None.

        Returns:
            array: The array of random values.
      )pbdoc");

  m.def(
      "truncated_normal",
      [](const ScalarOrArray& lower_,
         const ScalarOrArray& upper_,
         const std::optional<std::vector<int>>& shape_,
         std::optional<Dtype> type,
         const std::optional<array>& key,
         StreamOrDevice s) {
        auto lower = to_array(lower_);
        auto upper = to_array(upper_);
        auto t = type.value_or(float32);
        // Without an explicit shape, defer to the overload that takes only
        // the bounds, so the output shape comes from `lower`/`upper`.
        if (shape_.has_value()) {
          return truncated_normal(lower, upper, shape_.value(), t, key, s);
        }
        return truncated_normal(lower, upper, t, key, s);
      },
      "lower"_a,
      "upper"_a,
      "shape"_a = none,
      "dtype"_a = std::optional{float32},
      "key"_a = none,
      "stream"_a = none,
      R"pbdoc(
        Generate values from a truncated normal distribution.

        The values are sampled from the truncated normal distribution
        on the domain ``(lower, upper)``. The bounds ``lower`` and ``upper``
        can be scalars or arrays and must be broadcastable to ``shape``.

        Args:
            lower (scalar or array): Lower bound of the domain.
            upper (scalar or array): Upper bound of the domain.
            shape (list(int), optional): The shape of the output.
                Default: ``None``, in which case the shape is determined by
                ``lower`` and ``upper``.
            dtype (Dtype, optional): The data type of the output.
                Default is ``float32``.
            key (array, optional): A PRNG key. Default: None.

        Returns:
            array: The output array of random values.
      )pbdoc");

  m.def(
      "gumbel",
      [](const std::vector<int>& shape,
         std::optional<Dtype> type,
         const std::optional<array>& key,
         StreamOrDevice s) {
        return gumbel(shape, type.value_or(float32), key, s);
      },
      "shape"_a = std::vector<int>{},
      "dtype"_a = std::optional{float32},
      // Order must mirror the lambda parameters (key, then stream); the
      // previous stream-before-key order routed `key=` into the stream slot.
      "key"_a = none,
      "stream"_a = none,
      R"pbdoc(
        Sample from the standard Gumbel distribution.

        The values are sampled from a standard Gumbel distribution
        with CDF ``exp(-exp(-x))``.

        Args:
            shape (list(int), optional): The shape of the output.
                Default is ``()``.
            dtype (Dtype, optional): The data type of the output.
                Default is ``float32``.
            key (array, optional): A PRNG key. Default: None.

        Returns:
            array: The :class:`array` with shape ``shape`` and
            distributed according to the Gumbel distribution.
      )pbdoc");

  m.def(
      "categorical",
      [](const array& logits,
         int axis,
         const std::optional<std::vector<int>>& shape,
         const std::optional<int> num_samples,
         const std::optional<array>& key,
         StreamOrDevice s) {
        if (shape.has_value() && num_samples.has_value()) {
          throw std::invalid_argument(
              "[categorical] At most one of shape or num_samples can be specified.");
        }
        if (shape.has_value()) {
          return categorical(logits, axis, shape.value(), key, s);
        }
        if (num_samples.has_value()) {
          return categorical(logits, axis, num_samples.value(), key, s);
        }
        return categorical(logits, axis, key, s);
      },
      "logits"_a,
      "axis"_a = -1,
      "shape"_a = none,
      "num_samples"_a = none,
      "key"_a = none,
      "stream"_a = none,
      R"pbdoc(
        Sample from a categorical distribution.

        The values are sampled from the categorical distribution specified by
        the unnormalized values in ``logits``. Note, at most one of ``shape``
        or ``num_samples`` can be specified. If both are ``None``, the output
        has the same shape as ``logits`` with the ``axis`` dimension removed.

        Args:
            logits (array): The *unnormalized* categorical distribution(s).
            axis (int, optional): The axis which specifies the distribution.
                Default is ``-1``.
            shape (list(int), optional): The shape of the output. This must
                be broadcast compatible with ``logits.shape`` with the ``axis``
                dimension removed. Default: ``None``.
            num_samples (int, optional): The number of samples to draw from each
                of the categorical distributions in ``logits``. The output will have
                ``num_samples`` in the last dimension. Default: ``None``.
            key (array, optional): A PRNG key. Default: None.

        Returns:
            array: The ``shape``-sized output array with type ``uint32``.
      )pbdoc");
}