Skip to content

Commit

Permalink
Add truncated normal distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
Erotemic committed Jan 18, 2019
1 parent b8d4279 commit c0079aa
Showing 1 changed file with 69 additions and 1 deletion.
70 changes: 69 additions & 1 deletion imgaug/parameters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from __future__ import print_function, division, absolute_import

import copy as copy_module
from collections import defaultdict
from abc import ABCMeta, abstractmethod
Expand All @@ -8,6 +7,7 @@
import six
import six.moves as sm
import scipy
import scipy.stats

from . import imgaug as ia
from .external.opensimplex import OpenSimplex
Expand Down Expand Up @@ -753,6 +753,74 @@ def __str__(self):
return "Normal(loc=%s, scale=%s)" % (self.loc, self.scale)


class TruncNormal(StochasticParameter):
"""
Parameter that resembles a truncated normal distribution.
A truncated normal distribution is very close to a normal distribution
except the domain is smoothly bounded.
This is a wrapper around scipy.stats.truncnorm.
Parameters
----------
loc : number or imgaug.parameters.StochasticParameter
The mean of the normal distribution.
If StochasticParameter, the mean will be sampled once per call
to :func:`imgaug.parameters.Normal._draw_samples`.
scale : number or imgaug.parameters.StochasticParameter
The standard deviation of the normal distribution.
If StochasticParameter, the scale will be sampled once per call
to :func:`imgaug.parameters.Normal._draw_samples`.
low : number or imgaug.parameters.StochasticParameter
The minimum value of the truncated normal distribution.
If StochasticParameter, the scale will be sampled once per call
to :func:`imgaug.parameters.Normal._draw_samples`.
high : number or imgaug.parameters.StochasticParameter
The maximum value of the truncated normal distribution.
If StochasticParameter, the scale will be sampled once per call
to :func:`imgaug.parameters.Normal._draw_samples`.
Examples
--------
>>> param = TruncNormal(0, 5.0, low=-10, high=10)
>>> samples = param.draw_samples(100, random_state=np.random.RandomState(0))
>>> assert np.all(samples >= -10)
>>> assert np.all(samples <= 10)
"""
def __init__(self, loc, scale, low=-np.inf, high=np.inf):
super(TruncNormal, self).__init__()

self.loc = handle_continuous_param(loc, "loc")
self.scale = handle_continuous_param(scale, "scale", value_range=(0, None))
self.low = handle_continuous_param(low, "low")
self.high = handle_continuous_param(high, "high")

def _draw_samples(self, size, random_state):
loc = self.loc.draw_sample(random_state=random_state)
scale = self.scale.draw_sample(random_state=random_state)
low = self.low.draw_sample(random_state=random_state)
high = self.high.draw_sample(random_state=random_state)
if low > high:
low, high = high, low
ia.do_assert(scale >= 0, "Expected scale to be in range [0, inf), got %s." % (scale,))
a = (low - loc) / scale
b = (high - loc) / scale
rv = scipy.stats.truncnorm(a=a, b=b, loc=loc, scale=scale)
return rv.rvs(size=size, random_state=random_state)

def __repr__(self):
return self.__str__()

def __str__(self):
return "Normal(loc=%s, scale=%s, low=%s, high=%s)" % (
self.loc, self.scale, self.low, self.high)


class Laplace(StochasticParameter):
"""
Parameter that resembles a (continuous) laplace distribution.
Expand Down

0 comments on commit c0079aa

Please sign in to comment.