Skip to content

Commit

Permalink
ENH: stats: make ortho_group freezable (scipy#15653)
Browse files Browse the repository at this point in the history
* ENH: stats: make ortho_group freezable
  • Loading branch information
NamamiShanker authored Feb 26, 2022
1 parent bfffbd4 commit dbc2c3d
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 0 deletions.
62 changes: 62 additions & 0 deletions scipy/stats/_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3444,6 +3444,19 @@ class ortho_group_gen(multi_rv_generic):
rvs(dim=None, size=1, random_state=None)
Draw random samples from O(N).
Parameters
----------
dim : scalar
Dimension of matrices
seed : {None, int, np.random.RandomState, np.random.Generator}, optional
Used for drawing random variates.
If `seed` is `None`, the `~np.random.RandomState` singleton is used.
If `seed` is an int, a new ``RandomState`` instance is used, seeded
with seed.
If `seed` is already a ``RandomState`` or ``Generator`` instance,
then that object is used.
Default is `None`.
Notes
-----
This class is closely related to `special_ortho_group`.
Expand Down Expand Up @@ -3472,12 +3485,29 @@ class ortho_group_gen(multi_rv_generic):
This generates one random matrix from O(3). It is orthogonal and
has a determinant of +1 or -1.
Alternatively, the object may be called (as a function) to fix the `dim`
parameter, returning a "frozen" ortho_group random variable:
>>> rv = ortho_group(5)
>>> # Frozen object with the same methods but holding the
>>> # dimension parameter fixed.
See Also
--------
special_ortho_group
"""

def __init__(self, seed=None):
super().__init__(seed)
self.__doc__ = doccer.docformat(self.__doc__)

def __call__(self, dim=None, seed=None):
"""Create a frozen O(N) distribution.
See `ortho_group_frozen` for more information.
"""
return ortho_group_frozen(dim, seed=seed)

def _process_parameters(self, dim):
"""Dimension N must be specified; it cannot be inferred."""
if dim is None or not np.isscalar(dim) or dim <= 1 or dim != int(dim):
Expand Down Expand Up @@ -3528,6 +3558,38 @@ def rvs(self, dim, size=1, random_state=None):
ortho_group = ortho_group_gen()


class ortho_group_frozen(multi_rv_frozen):
def __init__(self, dim=None, seed=None):
"""Create a frozen O(N) distribution.
Parameters
----------
dim : scalar
Dimension of matrices
seed : {None, int, `numpy.random.Generator`,
`numpy.random.RandomState`}, optional
If `seed` is None (or `np.random`), the `numpy.random.RandomState`
singleton is used.
If `seed` is an int, a new ``RandomState`` instance is used,
seeded with `seed`.
If `seed` is already a ``Generator`` or ``RandomState`` instance
then that instance is used.
Examples
--------
>>> from scipy.stats import ortho_group
>>> g = ortho_group(5)
>>> x = g.rvs()
"""
self._dist = ortho_group_gen(seed)
self.dim = self._dist._process_parameters(dim)

def rvs(self, size=1, random_state=None):
return self._dist.rvs(self.dim, size, random_state)


class random_correlation_gen(multi_rv_generic):
r"""A random correlation matrix.
Expand Down
12 changes: 12 additions & 0 deletions scipy/stats/tests/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1430,6 +1430,18 @@ def test_invalid_dim(self):
assert_raises(ValueError, ortho_group.rvs, 1)
assert_raises(ValueError, ortho_group.rvs, 2.5)

def test_frozen_matrix(self):
dim = 7
frozen = ortho_group(dim)
frozen_seed = ortho_group(dim, seed=1234)

rvs1 = frozen.rvs(random_state=1234)
rvs2 = ortho_group.rvs(dim, random_state=1234)
rvs3 = frozen_seed.rvs(size=1)

assert_equal(rvs1, rvs2)
assert_equal(rvs1, rvs3)

def test_det_and_ortho(self):
xs = [[ortho_group.rvs(dim)
for i in range(10)]
Expand Down

0 comments on commit dbc2c3d

Please sign in to comment.