Do not use an MCSampler in MaxPosteriorSampling (pytorch#701)
Summary:
By using a sampler, we always end up passing `base_samples` downstream to GPyTorch's `MultivariateNormal`, which means sampling always went through `root_decomposition` to correlate the base samples rather than through `zero_mean_mvn_samples`. As a result, CIQ (contour integral quadrature) was never used.
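
To illustrate, a minimal sketch of the two code paths (the toy model, data, and sizes below are illustrative; the sampler arguments mirror the removed code and use the BoTorch/GPyTorch APIs as of this commit):

import torch
from botorch.models import SingleTaskGP
from botorch.sampling.samplers import IIDNormalSampler

# Toy GP on random data; hyperparameters are left at their defaults.
train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = train_X.sin().sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
posterior = model.posterior(torch.rand(50, 2, dtype=torch.double))

# Old path: the sampler constructs base_samples and hands them to GPyTorch's
# MultivariateNormal, which must correlate them via root_decomposition,
# so CIQ is never reached.
sampler = IIDNormalSampler(num_samples=16, collapse_batch_dims=False, resample=True)
samples_old = sampler(posterior)  # 16 x 50 x 1

# New path: drawing directly from the posterior passes no base_samples,
# so GPyTorch can use zero_mean_mvn_samples (and hence CIQ when enabled).
samples_new = posterior.rsample(sample_shape=torch.Size([16]))  # 16 x 50 x 1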

Pull Request resolved: pytorch#701

Reviewed By: dme65

Differential Revision: D26539941

Pulled By: Balandat

fbshipit-source-id: 3e0957fadb48153da656303e7bbc1fabc2900b5e
Balandat authored and facebook-github-bot committed Feb 19, 2021
1 parent 33034ea commit 2941104
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions botorch/generation/sampling.py
@@ -28,7 +28,6 @@
 )
 from botorch.generation.utils import _flip_sub_unique
 from botorch.models.model import Model
-from botorch.sampling.samplers import IIDNormalSampler
 from botorch.utils.sampling import batched_multinomial
 from botorch.utils.transforms import standardize
 from torch import Tensor
@@ -106,10 +105,8 @@ def forward(
         if isinstance(self.objective, ScalarizedObjective):
             posterior = self.objective(posterior)

-        sampler = IIDNormalSampler(
-            num_samples=num_samples, collapse_batch_dims=False, resample=True
-        )
-        samples = sampler(posterior)  # num_samples x batch_shape x N x m
+        # num_samples x batch_shape x N x m
+        samples = posterior.rsample(sample_shape=torch.Size([num_samples]))
         if isinstance(self.objective, ScalarizedObjective):
             obj = samples.squeeze(-1)  # num_samples x batch_shape x N
         else:
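
At the call site nothing changes; a usage sketch of MaxPosteriorSampling after this fix (the model and candidate set below are placeholders):

import torch
from botorch.models import SingleTaskGP
from botorch.generation.sampling import MaxPosteriorSampling

# Placeholder model and discrete candidate set.
train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
X_cand = torch.rand(512, 2, dtype=torch.double)

# forward() now draws joint posterior samples via posterior.rsample (no
# base_samples reach GPyTorch) and returns the argmax of each sample.
thompson = MaxPosteriorSampling(model=model, replacement=False)
X_next = thompson(X_cand, num_samples=5)  # 5 x 2 candidates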
