-
Notifications
You must be signed in to change notification settings - Fork 631
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
fn.random.beta
random variate (#5550)
* Add beta distribution random sampler * Add random beta tests to all-ops suites * Expose random.beta in eager ops --------- Signed-off-by: Kamil Tokarski <[email protected]>
- Loading branch information
Showing
8 changed files
with
483 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef DALI_OPERATORS_RANDOM_BETA_DISTRIBUTION_H_ | ||
#define DALI_OPERATORS_RANDOM_BETA_DISTRIBUTION_H_ | ||
|
||
#include <random> | ||
#include <vector> | ||
#include "dali/operators/random/rng_base.h" | ||
#include "dali/pipeline/operator/arg_helper.h" | ||
|
||
#define DALI_BETA_DIST_TYPES float, double | ||
|
||
namespace dali { | ||
|
||
template <typename T> | ||
struct BetaDistributionImpl { | ||
BetaDistributionImpl() : BetaDistributionImpl{1, 1} {} | ||
|
||
explicit BetaDistributionImpl(T alpha, T beta) | ||
: has_small_param_{alpha < 1 && beta < 1}, | ||
alpha_{alpha}, | ||
beta_{beta}, | ||
b_div_a_{beta_ / alpha_}, | ||
a_{has_small_param_ ? alpha_ + 1 : alpha_}, | ||
b_{has_small_param_ ? beta_ + 1 : beta_}, | ||
exp_{} {} | ||
|
||
template <typename Generator> | ||
T Generate(Generator &st) { | ||
return has_small_param_ ? GenerateGammaExp(st) : GenerateGamma(st); | ||
} | ||
|
||
private: | ||
template <typename Generator> | ||
T GenerateGamma(Generator &st) { | ||
// https://en.wikipedia.org/wiki/Beta_distribution#Random_variate_generation | ||
T a = a_(st); | ||
T b = b_(st); | ||
return a / (a + b); | ||
} | ||
|
||
template <typename Generator> | ||
T GenerateGammaExp(Generator &st) { | ||
assert(alpha_ < 1 && beta_ < 1); | ||
// For alpha >= 1 and x in [0, 1], Gamma(alpha).cdf(x) <= Exp(1).cdf(x), | ||
// i.e., the probability of sampling [0, x] is less than 1 - exp(-x) <= x, | ||
// so the probability of the denominator in `A / (A + B)` being rounded to zero | ||
// is negligable. | ||
// However, for alpha < 1, the gamma distribution has high density near 0. | ||
// Here, we use the fact that for X ~ Gamma(alpha + 1), U ~ Uniform(0, 1)^(1/alpha), | ||
// the `A = X * U` has Gamma(alpha) distribution (see Luc Devroye, Non-Uniform Random | ||
// Variate Generation, p. 182) | ||
// We can compute A / (A + B) = X * U_a / (X * U_a + Y * U_b) = | ||
// (X * U_a / max(U_a, U_b)) / (X * U_a / max(U_a, U_b) + Y * U_b / max(U_a, U_b)). | ||
// This way, we have either X or Y in the denominator sampled from "safe" Gamma | ||
// with params > 1. | ||
// The Uniform(0, 1)^(1/alpha) is concentrated near zero, for this reason | ||
// the U_a / max(U_a, U_b), U_b / max(U_a, U_b) are computed in logarithmic scale. | ||
T a = a_(st); | ||
T b = b_(st); | ||
// By inverse transform sampling, the ln(Uniform(0, 1)) = -Exp(1). | ||
T ln_ua = -exp_(st); | ||
T ln_ub = -exp_(st); | ||
// -Exp_a / alpha_ < -Exp_b / beta_, iff -Exp_a * (beta_/alpha_) < -Exp_b | ||
// but in that form, we can handle subnormal alpha or beta, that would | ||
// result in two inifnities otherwise | ||
if (ln_ua * b_div_a_ < ln_ub) { | ||
// ln_ua / alpha - ln_ub / beta | ||
T c = (ln_ua * b_div_a_ - ln_ub) / beta_; // [-inf, 0] | ||
T ac = a * std::exp(c); // a * [0, 1] | ||
return ac / (ac + b); | ||
} else { | ||
// ln_ub / beta - ln_ua / alpha | ||
T c = (ln_ub - ln_ua * b_div_a_) / beta_; // [-inf, 0] | ||
T bc = b * std::exp(c); // b * [0, 1] | ||
return a / (a + bc); | ||
} | ||
} | ||
|
||
bool has_small_param_; | ||
T alpha_, beta_, b_div_a_; | ||
std::gamma_distribution<T> a_; | ||
std::gamma_distribution<T> b_; | ||
std::exponential_distribution<T> exp_; | ||
}; | ||
|
||
template <typename Backend> | ||
class BetaDistribution : public rng::RNGBase<Backend, BetaDistribution<Backend>, false> { | ||
public: | ||
using Base = rng::RNGBase<Backend, BetaDistribution<Backend>, false>; | ||
static_assert(std::is_same_v<Backend, CPUBackend>, "GPU backend is not implemented"); | ||
|
||
explicit BetaDistribution(const OpSpec &spec) | ||
: Base(spec), alpha_("alpha", spec), beta_("beta", spec) {} | ||
|
||
void AcquireArgs(const OpSpec &spec, const Workspace &ws, int nsamples) { | ||
// read only once for build time arguments | ||
if (alpha_.HasArgumentInput() || !alpha_.size()) { | ||
alpha_.Acquire(spec, ws, alpha_.HasArgumentInput() ? nsamples : max_batch_size_); | ||
for (int sample_idx = 0; sample_idx < nsamples; sample_idx++) { | ||
auto alpha = alpha_[sample_idx].data[0]; | ||
DALI_ENFORCE(alpha > 0 && std::isfinite(alpha), | ||
make_string("The `alpha` must be a positive float32, got `", alpha, | ||
"` for sample at index ", sample_idx, ".")); | ||
} | ||
} | ||
if (beta_.HasArgumentInput() || !beta_.size()) { | ||
beta_.Acquire(spec, ws, beta_.HasArgumentInput() ? nsamples : max_batch_size_); | ||
for (int sample_idx = 0; sample_idx < nsamples; sample_idx++) { | ||
auto beta = beta_[sample_idx].data[0]; | ||
DALI_ENFORCE(beta > 0 && std::isfinite(beta), | ||
make_string("The `beta` must be a positive float32, got `", beta, | ||
"` for sample at index ", sample_idx, ".")); | ||
} | ||
} | ||
} | ||
|
||
DALIDataType DefaultDataType(const OpSpec &spec, const Workspace &ws) const { | ||
return DALI_FLOAT; | ||
} | ||
|
||
template <typename T> | ||
bool SetupDists(BetaDistributionImpl<T> *dists, const Workspace &ws, int nsamples) { | ||
for (int sample_idx = 0; sample_idx < nsamples; sample_idx++) { | ||
auto alpha = alpha_[sample_idx].data[0]; | ||
auto beta = beta_[sample_idx].data[0]; | ||
dists[sample_idx] = | ||
BetaDistributionImpl<T>{alpha_[sample_idx].data[0], beta_[sample_idx].data[0]}; | ||
} | ||
return true; | ||
} | ||
|
||
template <typename T> | ||
void RunImplTyped(Workspace &ws) { | ||
Base::template RunImplTyped<T, BetaDistributionImpl<T>>(ws); | ||
} | ||
|
||
void RunImpl(Workspace &ws) override { | ||
TYPE_SWITCH(dtype_, type2id, T, (DALI_BETA_DIST_TYPES), ( | ||
this->template RunImplTyped<T>(ws); | ||
), ( // NOLINT | ||
DALI_FAIL(make_string("Data type ", dtype_, " is not supported. " | ||
"Supported types are : ", ListTypeNames<DALI_BETA_DIST_TYPES>())); | ||
)); // NOLINT | ||
} | ||
|
||
protected: | ||
using Base::dtype_; | ||
using Base::max_batch_size_; | ||
|
||
ArgValue<float, 0> alpha_; | ||
ArgValue<float, 0> beta_; | ||
}; | ||
|
||
} // namespace dali | ||
|
||
#endif // DALI_OPERATORS_RANDOM_BETA_DISTRIBUTION_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <vector> | ||
#include "dali/operators/random/beta_distribution.h" | ||
#include "dali/operators/random/rng_base_cpu.h" | ||
|
||
namespace dali { | ||
|
||
DALI_SCHEMA(random__Beta) | ||
.DocStr(R"code(Generates a random number from ``[0, 1]`` range following the beta distribution. | ||
The beta distribution has the following probabilty distribution function: | ||
.. math:: f(x) = \frac{\Gamma(\alpha + \beta)}{\Gamma(\alpha)\Gamma(\beta)} x^{\alpha-1} (1-x)^{\beta-1} | ||
where ``Г`` is the gamma function defined as: | ||
.. math:: \Gamma(\alpha) = \int_0^\infty x^{\alpha-1} e^{-x} \, dx | ||
The operator supports ``float32`` and ``float64`` output types. | ||
The shape of the generated data can be either specified explicitly with a ``shape`` argument, | ||
or chosen to match the shape of the ``__shape_like`` input, if provided. If none are present, | ||
a single value per sample is generated. | ||
)code") | ||
.NumInput(0, 1) | ||
.InputDox(0, "shape_like", "TensorList", | ||
"Shape of this input will be used to infer the shape of the output, if provided.") | ||
.NumOutput(1) | ||
.AddOptionalArg("alpha", R"code(The alpha parameter, a positive ``float32`` scalar.)code", 1.0f, | ||
true) | ||
.AddOptionalArg("beta", R"code(The beta parameter, a positive ``float32`` scalar.)code", 1.0f, | ||
true) | ||
.AddParent("RNGAttr"); | ||
|
||
DALI_REGISTER_OPERATOR(random__Beta, BetaDistribution<CPUBackend>, CPU); | ||
|
||
} // namespace dali |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.