Skip to content

Commit

Permalink
Delete references to deprecated DeterministicPosterior and Determinis…
Browse files Browse the repository at this point in the history
…ticSampler (pytorch#2391)

Summary:
## Motivation

Delete references to deprecated DeterministicPosterior and DeterministicSampler, replacing `DeterministicPosterior` with `EnsemblePosterior` where appropriate. These were deprecated before 0.9.0, so now that we are past 0.11.0, they can be reaped.

Pull Request resolved: pytorch#2391

Test Plan:
Replaced `DeterministicPosterior` with `EnsemblePosterior` in tests

## Related PRs

pytorch#1636

Reviewed By: Balandat

Differential Revision: D59057165

Pulled By: esantorella

fbshipit-source-id: 70a8405d0e75d414a685192808e8c0b18a6aca92
  • Loading branch information
esantorella authored and facebook-github-bot committed Jul 2, 2024
1 parent 32bdfda commit ee06209
Show file tree
Hide file tree
Showing 10 changed files with 24 additions and 208 deletions.
2 changes: 0 additions & 2 deletions botorch/posteriors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from botorch.posteriors.deterministic import DeterministicPosterior
from botorch.posteriors.fully_bayesian import (
FullyBayesianPosterior,
GaussianMixturePosterior,
Expand All @@ -18,7 +17,6 @@
from botorch.posteriors.transformed import TransformedPosterior

__all__ = [
"DeterministicPosterior",
"GaussianMixturePosterior",
"FullyBayesianPosterior",
"GPyTorchPosterior",
Expand Down
86 changes: 0 additions & 86 deletions botorch/posteriors/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,89 +3,3 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Deterministic (degenerate) posteriors. Used in conjunction with deterministic
models.
"""

from __future__ import annotations

from typing import Optional
from warnings import warn

import torch
from botorch.posteriors.posterior import Posterior
from torch import Tensor


class DeterministicPosterior(Posterior):
r"""Deterministic posterior.
[DEPRECATED] Use `EnsemblePosterior` instead.
"""

def __init__(self, values: Tensor) -> None:
r"""
Args:
values: Values of the samples produced by this posterior.
"""
warn(
"`DeterministicPosterior` is marked for deprecation, consider using "
"`EnsemblePosterior`.",
DeprecationWarning,
)
self.values = values

@property
def device(self) -> torch.device:
r"""The torch device of the posterior."""
return self.values.device

@property
def dtype(self) -> torch.dtype:
r"""The torch dtype of the posterior."""
return self.values.dtype

def _extended_shape(
self, sample_shape: torch.Size = torch.Size() # noqa: B008
) -> torch.Size:
r"""Returns the shape of the samples produced by the posterior with
the given `sample_shape`.
"""
return sample_shape + self.values.shape

@property
def mean(self) -> Tensor:
r"""The mean of the posterior as a `(b) x n x m`-dim Tensor."""
return self.values

@property
def variance(self) -> Tensor:
r"""The variance of the posterior as a `(b) x n x m`-dim Tensor.
As this is a deterministic posterior, this is a tensor of zeros.
"""
return torch.zeros_like(self.values)

def rsample(
self,
sample_shape: Optional[torch.Size] = None,
) -> Tensor:
r"""Sample from the posterior (with gradients).
For the deterministic posterior, this just returns the values expanded
to the requested shape.
Args:
sample_shape: A `torch.Size` object specifying the sample shape. To
draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
of `n` samples each, set to `torch.Size([b, n])`.
Returns:
Samples from the posterior, a tensor of shape
`self._extended_shape(sample_shape=sample_shape)`.
"""
if sample_shape is None:
sample_shape = torch.Size([1])
return self.values.expand(self._extended_shape(sample_shape))
2 changes: 0 additions & 2 deletions botorch/sampling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

from botorch.sampling.base import MCSampler
from botorch.sampling.deterministic import DeterministicSampler
from botorch.sampling.get_sampler import get_sampler
from botorch.sampling.list_sampler import ListSampler
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
Expand All @@ -20,7 +19,6 @@


__all__ = [
"DeterministicSampler",
"ForkedRNGSampler",
"get_sampler",
"IIDNormalSampler",
Expand Down
32 changes: 0 additions & 32 deletions botorch/sampling/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,3 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
A dummy sampler for use with deterministic models.
"""

from __future__ import annotations

from botorch.posteriors.deterministic import DeterministicPosterior
from botorch.sampling.stochastic_samplers import StochasticSampler


class DeterministicSampler(StochasticSampler):
r"""A sampler that simply calls `posterior.rsample`, intended to be used with
`DeterministicModel` & `DeterministicPosterior`.
[DEPRECATED] - Use `IndexSampler` in conjunction with `EnsemblePosterior`
instead of `DeterministicSampler` with `DeterministicPosterior`.
This is effectively signals that `StochasticSampler` is safe to use with
deterministic models since their output is deterministic by definition.
"""

def _update_base_samples(
self, posterior: DeterministicPosterior, base_sampler: DeterministicSampler
) -> None:
r"""This is a no-op since there are no base samples to update.
Args:
posterior: The posterior for which the base samples are constructed.
base_sampler: The base sampler to retrieve the base samples from.
"""
return
12 changes: 0 additions & 12 deletions botorch/sampling/get_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@

import torch
from botorch.logging import logger
from botorch.posteriors.deterministic import DeterministicPosterior
from botorch.posteriors.ensemble import EnsemblePosterior
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior import Posterior
from botorch.posteriors.posterior_list import PosteriorList
from botorch.posteriors.torch import TorchPosterior
from botorch.posteriors.transformed import TransformedPosterior
from botorch.sampling.base import MCSampler
from botorch.sampling.deterministic import DeterministicSampler
from botorch.sampling.index_sampler import IndexSampler
from botorch.sampling.list_sampler import ListSampler
from botorch.sampling.normal import (
Expand Down Expand Up @@ -119,16 +117,6 @@ def _get_sampler_list(
return ListSampler(*samplers)


@GetSampler.register(DeterministicPosterior)
def _get_sampler_deterministic(
posterior: DeterministicPosterior,
sample_shape: torch.Size,
seed: Optional[int] = None,
) -> MCSampler:
r"""Get the dummy `DeterministicSampler` for the `DeterministicPosterior`."""
return DeterministicSampler(sample_shape=sample_shape, seed=seed)


@GetSampler.register(EnsemblePosterior)
def _get_sampler_ensemble(
posterior: EnsemblePosterior,
Expand Down
7 changes: 5 additions & 2 deletions test/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from botorch.exceptions.errors import InputDataError
from botorch.models.deterministic import GenericDeterministicModel
from botorch.models.model import Model, ModelDict, ModelList
from botorch.posteriors.deterministic import DeterministicPosterior
from botorch.posteriors.ensemble import EnsemblePosterior
from botorch.posteriors.posterior_list import PosteriorList
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
Expand All @@ -35,7 +35,10 @@ def evaluate(self, Y):

def forward(self, posterior):
return PosteriorList(
*[DeterministicPosterior(2 * p.mean + 1) for p in posterior.posteriors]
*[
EnsemblePosterior(2 * p.mean.unsqueeze(0) + 1)
for p in posterior.posteriors
]
)


Expand Down
34 changes: 0 additions & 34 deletions test/posteriors/test_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,3 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools

from warnings import catch_warnings

import torch
from botorch.posteriors.deterministic import DeterministicPosterior
from botorch.utils.testing import BotorchTestCase


class TestDeterministicPosterior(BotorchTestCase):
def test_DeterministicPosterior(self):
for shape, dtype in itertools.product(
((3, 2), (2, 3, 1)), (torch.float, torch.double)
):
values = torch.randn(*shape, device=self.device, dtype=dtype)
p = DeterministicPosterior(values)
with catch_warnings(record=True) as ws:
p = DeterministicPosterior(values)
self.assertTrue(
any("marked for deprecation" in str(w.message) for w in ws)
)
self.assertEqual(p.device.type, self.device.type)
self.assertEqual(p.dtype, dtype)
self.assertEqual(p._extended_shape(), values.shape)
with self.assertRaises(NotImplementedError):
p.base_sample_shape
self.assertTrue(torch.equal(p.mean, values))
self.assertTrue(torch.equal(p.variance, torch.zeros_like(values)))
# test sampling
samples = p.rsample()
self.assertTrue(torch.equal(samples, values.unsqueeze(0)))
samples = p.rsample(torch.Size([2]))
self.assertTrue(torch.equal(samples, values.expand(2, *values.shape)))
4 changes: 2 additions & 2 deletions test/posteriors/test_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import torch
from botorch.posteriors import GPyTorchPosterior, Posterior, PosteriorList
from botorch.posteriors.deterministic import DeterministicPosterior
from botorch.posteriors.ensemble import EnsemblePosterior
from botorch.utils.testing import BotorchTestCase
from gpytorch.distributions import MultivariateNormal
from linear_operator.operators import to_linear_operator
Expand Down Expand Up @@ -57,7 +57,7 @@ def _make_gpytorch_posterior(self, shape, dtype):

def _make_deterministic_posterior(self, shape, dtype):
mean = torch.rand(*shape, 1, dtype=dtype, device=self.device)
return DeterministicPosterior(values=mean)
return EnsemblePosterior(values=mean.unsqueeze(0))

def test_posterior_list(self):
for dtype, use_deterministic in product(
Expand Down
19 changes: 0 additions & 19 deletions test/sampling/test_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,3 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from botorch.sampling.deterministic import DeterministicSampler
from botorch.utils.testing import BotorchTestCase, MockPosterior


class TestDeterministicSampler(BotorchTestCase):
def test_deterministic_sampler(self):
# Basic usage.
samples = torch.rand(1, 2)
posterior = MockPosterior(samples=samples)
sampler = DeterministicSampler(sample_shape=torch.Size([2]))
self.assertTrue(torch.equal(samples.repeat(2, 1, 1), sampler(posterior)))

# Test _update_base_samples.
sampler._update_base_samples(
posterior=posterior,
base_sampler=sampler,
)
34 changes: 17 additions & 17 deletions test/sampling/test_get_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
# LICENSE file in the root directory of this source tree.

import torch
from botorch.posteriors.deterministic import DeterministicPosterior
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior_list import PosteriorList
from botorch.posteriors.torch import TorchPosterior
from botorch.posteriors.transformed import TransformedPosterior
from botorch.sampling.get_sampler import get_sampler
from botorch.sampling.list_sampler import ListSampler
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
from botorch.sampling.stochastic_samplers import StochasticSampler
from botorch.utils.testing import BotorchTestCase
from gpytorch.distributions import MultivariateNormal
from torch.distributions.gamma import Gamma
Expand All @@ -22,40 +20,42 @@
class TestGetSampler(BotorchTestCase):
def test_get_sampler(self):
# Basic usage w/ gpytorch posterior.
posterior = GPyTorchPosterior(
mvn_posterior = GPyTorchPosterior(
distribution=MultivariateNormal(torch.rand(2), torch.eye(2))
)
seed = 2
n_samples = 10
sampler = get_sampler(
posterior=posterior, sample_shape=torch.Size([10]), seed=2
posterior=mvn_posterior, sample_shape=torch.Size([n_samples]), seed=seed
)
self.assertIsInstance(sampler, SobolQMCNormalSampler)
self.assertEqual(sampler.seed, 2)
self.assertEqual(sampler.sample_shape, torch.Size([10]))
self.assertEqual(sampler.seed, seed)
self.assertEqual(sampler.sample_shape, torch.Size([n_samples]))

# Fallback to IID sampler.
posterior = GPyTorchPosterior(
big_mvn_posterior = GPyTorchPosterior(
distribution=MultivariateNormal(torch.rand(22000), torch.eye(22000))
)
sampler = get_sampler(posterior=posterior, sample_shape=torch.Size([10]))
sampler = get_sampler(
posterior=big_mvn_posterior, sample_shape=torch.Size([n_samples])
)
self.assertIsInstance(sampler, IIDNormalSampler)
self.assertEqual(sampler.sample_shape, torch.Size([10]))
self.assertEqual(sampler.sample_shape, torch.Size([n_samples]))

# Transformed posterior.
tf_post = TransformedPosterior(
posterior=posterior, sample_transform=lambda X: X
posterior=big_mvn_posterior, sample_transform=lambda X: X
)
sampler = get_sampler(posterior=tf_post, sample_shape=torch.Size([10]))
sampler = get_sampler(posterior=tf_post, sample_shape=torch.Size([n_samples]))
self.assertIsInstance(sampler, IIDNormalSampler)
self.assertEqual(sampler.sample_shape, torch.Size([10]))
self.assertEqual(sampler.sample_shape, torch.Size([n_samples]))

# PosteriorList with transformed & deterministic.
post_list = PosteriorList(
tf_post, DeterministicPosterior(values=torch.rand(1, 2))
)
# PosteriorList with transformed & original
post_list = PosteriorList(tf_post, mvn_posterior)
sampler = get_sampler(posterior=post_list, sample_shape=torch.Size([5]))
self.assertIsInstance(sampler, ListSampler)
self.assertIsInstance(sampler.samplers[0], IIDNormalSampler)
self.assertIsInstance(sampler.samplers[1], StochasticSampler)
self.assertIsInstance(sampler.samplers[1], SobolQMCNormalSampler)
for s in sampler.samplers:
self.assertEqual(s.sample_shape, torch.Size([5]))

Expand Down

0 comments on commit ee06209

Please sign in to comment.