Skip to content

Commit

Permalink
Add a exponential distribution using the rate
Browse files Browse the repository at this point in the history
  • Loading branch information
aherbert committed Aug 10, 2023
1 parent aab807f commit 9f4cdad
Show file tree
Hide file tree
Showing 2 changed files with 336 additions and 0 deletions.
207 changes: 207 additions & 0 deletions gdsc-core/src/main/java/uk/ac/sussex/gdsc/core/math/Distributions.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
/*-
* #%L
* Genome Damage and Stability Centre Core Package
*
* Contains core utilities for image analysis and is used by:
*
* GDSC ImageJ Plugins - Microscopy image analysis
*
* GDSC SMLM ImageJ Plugins - Single molecule localisation microscopy (SMLM)
* %%
* Copyright (C) 2011 - 2023 Alex Herbert
* %%
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public
* License along with this program. If not, see
* <http://www.gnu.org/licenses/gpl-3.0.html>.
* #L%
*/

package uk.ac.sussex.gdsc.core.math;

import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
import org.apache.commons.statistics.distribution.ContinuousDistribution;

/**
* Contains methods for probability distributions.
*/
public final class Distributions {
/**
* Implement the exponential distribution parameterized by the rate (1 / mean).
*/
private static class RateExponentialDistribution implements ContinuousDistribution {
/** Support lower bound. */
private static final double SUPPORT_LO = 0;
/** Support upper bound. */
private static final double SUPPORT_HI = Double.POSITIVE_INFINITY;
/** ln(2). */
private static final double LN_2 = 0.6931471805599453094172;

/** The rate of this distribution. */
private final double rate;
/** The logarithm of the rate, stored to reduce computing time. */
private final double logRate;

/**
* @param rate Rate of this distribution.
*/
private RateExponentialDistribution(double rate) {
this.rate = rate;
logRate = Math.log(rate);
}

@Override
public double density(double x) {
if (x < SUPPORT_LO) {
return 0;
}
return Math.exp(-x * rate) * rate;
}

@Override
public double logDensity(double x) {
if (x < SUPPORT_LO) {
return Double.NEGATIVE_INFINITY;
}
return logRate - x * rate;
}

@Override
public double cumulativeProbability(double x) {
if (x <= SUPPORT_LO) {
return 0;
}
return -Math.expm1(-x * rate);
}

@Override
public double survivalProbability(double x) {
if (x <= SUPPORT_LO) {
return 1;
}
return Math.exp(-x * rate);
}

@Override
public double probability(double x0, double x1) {
if (x0 > x1) {
throw new IllegalArgumentException(
String.format("Lower bound %s > upper bound %s", x0, x1));
}
// Use the survival probability when in the upper domain:
final double median = LN_2 / rate;
if (x0 >= median) {
return survivalProbability(x0) - survivalProbability(x1);
}
return cumulativeProbability(x1) - cumulativeProbability(x0);
}

@Override
public double inverseCumulativeProbability(double p) {
checkProbability(p);
if (p == 1) {
return Double.POSITIVE_INFINITY;
}
// Subtract from zero to prevent returning -0.0 for p=-0.0
return 0 - Math.log1p(-p) / rate;
}

@Override
public double inverseSurvivalProbability(double p) {
checkProbability(p);
if (p == 0) {
return Double.POSITIVE_INFINITY;
}
// Subtract from zero to prevent returning -0.0 for p=1
return 0 - Math.log(p) / rate;
}

@Override
public double getMean() {
return 1 / rate;
}

@Override
public double getVariance() {
return 1 / (rate * rate);
}

@Override
public double getSupportLowerBound() {
return SUPPORT_LO;
}

@Override
public double getSupportUpperBound() {
return SUPPORT_HI;
}

@Override
public Sampler createSampler(UniformRandomProvider rng) {
// Exponential distribution sampler.
// Handle the edge case where the mean is infinite.
final double mean = getMean();
if (Double.isInfinite(mean)) {
final ZigguratSampler.Exponential sampler = ZigguratSampler.Exponential.of(rng);
return () -> sampler.sample() / rate;
}
return ZigguratSampler.Exponential.of(rng, mean)::sample;
}

/**
* Check the probability {@code p} is in the interval {@code [0, 1]}.
*
* @param p Probability
* @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
*/
private static void checkProbability(double p) {
if (p >= 0 && p <= 1) {
return;
}
// Out-of-range or NaN
throw new IllegalArgumentException("Invalid probability: " + p);
}
}

/** No public construction. */
private Distributions() {}

/**
* Return a new exponential distribution.
*
* <p>The probability density function of X is:
*
* <p> f(x; lambda) = lambda e^{-x * lambda}
*
* <p>This implementation uses the rate parameter {@code lambda} which is the inverse scale of the
* distribution. A common alternative parameterization uses the scale parameter {@code mu} which
* is the mean of the distribution. The distribution can be be created using
* {@code lambda = 1 / mu}. For a parameterisation using the mean see
* {@link org.apache.commons.statistics.distribution.ExponentialDistribution}.
*
* <p>Note this implementation is within a few ULP of a parameterisation using the mean. Only the
* log density may be very different; this occurs as the x value approaches the mean.
*
* @param lambda the rate parameter
* @return the continuous distribution
* @throws IllegalArgumentException if {@code rate <= 0}.
* @see <a href="https://en.wikipedia.org/wiki/Exponential_distribution">Exponential distribution
* (Wikipedia)</a>
*/
public static ContinuousDistribution exponential(double lambda) {
if (lambda <= 0) {
throw new IllegalArgumentException("Invalid rate: " + lambda);
}
return new RateExponentialDistribution(lambda);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*-
* #%L
* Genome Damage and Stability Centre Core Package
*
* Contains core utilities for image analysis and is used by:
*
* GDSC ImageJ Plugins - Microscopy image analysis
*
* GDSC SMLM ImageJ Plugins - Single molecule localisation microscopy (SMLM)
* %%
* Copyright (C) 2011 - 2023 Alex Herbert
* %%
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public
* License along with this program. If not, see
* <http://www.gnu.org/licenses/gpl-3.0.html>.
* #L%
*/

package uk.ac.sussex.gdsc.core.math;

import org.apache.commons.statistics.distribution.ContinuousDistribution;
import org.apache.commons.statistics.distribution.ContinuousDistribution.Sampler;
import org.apache.commons.statistics.distribution.ExponentialDistribution;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import uk.ac.sussex.gdsc.core.utils.SimpleArrayUtils;
import uk.ac.sussex.gdsc.test.api.Predicates;
import uk.ac.sussex.gdsc.test.api.TestAssertions;
import uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate;
import uk.ac.sussex.gdsc.test.rng.RngFactory;

/**
* Test for {@link GeometryUtils}.
*/
@SuppressWarnings({"javadoc"})
class DistributionsTest {
@ParameterizedTest
@ValueSource(doubles = {0.5, 1, 4})
void canComputeExponentialDistribution(double mean) {
double rate = 1 / mean;
ExponentialDistribution ed = ExponentialDistribution.of(mean);
ContinuousDistribution d = Distributions.exponential(rate);
DoubleDoubleBiPredicate test = Predicates.doublesAreUlpClose(5);
Assertions.assertEquals(ed.getSupportLowerBound(), d.getSupportLowerBound(), "lower bound");
Assertions.assertEquals(ed.getSupportUpperBound(), d.getSupportUpperBound(), "upper bound");
TestAssertions.assertTest(ed.getMean(), d.getMean(), test, "mean");
TestAssertions.assertTest(ed.getVariance(), d.getVariance(), test, "variance");

double[] x = SimpleArrayUtils.newArray(10, mean / 5, mean / 2);
for (int i = 0; i < x.length; i++) {
TestAssertions.assertTest(ed.density(x[i]), d.density(x[i]), test, "density");
TestAssertions.assertTest(ed.logDensity(x[i]), d.logDensity(x[i]), test, "logDensity");
if (i + 1 < x.length) {
TestAssertions.assertTest(ed.probability(x[i], x[i + 1]), d.probability(x[i], x[i + 1]),
test, "probability");
}
TestAssertions.assertTest(ed.cumulativeProbability(x[i]), d.cumulativeProbability(x[i]), test,
"cumulativeProbability");
TestAssertions.assertTest(ed.survivalProbability(x[i]), d.survivalProbability(x[i]), test,
"survivalProbability");
}

double[] p = SimpleArrayUtils.newArray(9, 0, 1 / 8.0);
for (int i = 0; i < p.length; i++) {
TestAssertions.assertTest(ed.inverseCumulativeProbability(p[i]),
d.inverseCumulativeProbability(p[i]), test, "inverseCumulativeProbability");
TestAssertions.assertTest(ed.inverseSurvivalProbability(p[i]),
d.inverseSurvivalProbability(p[i]), test, "inverseSurvivalProbability");
}

// Requires rate to be exactly invertible
if (mean == 1 / rate) {
Sampler s1 = ed.createSampler(RngFactory.createWithFixedSeed());
Sampler s2 = d.createSampler(RngFactory.createWithFixedSeed());
for (int i = 0; i < 10; i++) {
Assertions.assertEquals(s1.sample(), s2.sample());
}
}
}

@Test
void canComputeExponentialDistributionEdgeCases() {
ExponentialDistribution ed = ExponentialDistribution.of(1);
ContinuousDistribution d = Distributions.exponential(1);
for (double x : new double[] {-1, 0}) {
Assertions.assertEquals(ed.density(x), d.density(x), "density");
// Allow -0.0 == 0.0
Assertions.assertEquals(ed.logDensity(x), d.logDensity(x), 0.0, "logDensity");
Assertions.assertEquals(ed.cumulativeProbability(x), d.cumulativeProbability(x),
"cumulativeProbability");
Assertions.assertEquals(ed.survivalProbability(x), d.survivalProbability(x),
"survivalProbability");
}
Assertions.assertThrows(IllegalArgumentException.class, () -> Distributions.exponential(0));
Assertions.assertThrows(IllegalArgumentException.class, () -> Distributions.exponential(-1));
Assertions.assertThrows(IllegalArgumentException.class, () -> d.probability(1, 0.5));
Assertions.assertThrows(IllegalArgumentException.class,
() -> d.inverseCumulativeProbability(-1));
Assertions.assertThrows(IllegalArgumentException.class,
() -> d.inverseCumulativeProbability(1.5));
Assertions.assertThrows(IllegalArgumentException.class, () -> d.inverseSurvivalProbability(-1));
Assertions.assertThrows(IllegalArgumentException.class,
() -> d.inverseSurvivalProbability(1.5));
}

@Test
void canComputeExponentialDistributionSamplesWithInfiniteMean() {
double rate = Double.MIN_NORMAL / 4;
ContinuousDistribution d = Distributions.exponential(rate);
Assertions.assertEquals(Double.POSITIVE_INFINITY, d.getMean(), "mean");
Sampler s1 = ExponentialDistribution.of(1).createSampler(RngFactory.createWithFixedSeed());
Sampler s2 = d.createSampler(RngFactory.createWithFixedSeed());
for (int i = 0; i < 10; i++) {
Assertions.assertEquals(s1.sample() / rate, s2.sample());
}
}
}

0 comments on commit 9f4cdad

Please sign in to comment.