Skip to content

Commit

Permalink
Add prior predictive sampling to PC model
Browse files Browse the repository at this point in the history
  • Loading branch information
camirmas committed Apr 5, 2023
1 parent f4382c1 commit 9ab8327
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 61 deletions.
23 changes: 22 additions & 1 deletion REStats/models/power_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,28 @@ def forward(self, x):
"""
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)

return MultivariateNormal(mean_x, covar_x)


def prior_predictive_samples(self, x, n_samples=1):
"""
Generate prior predictive samples for the given inputs.
Args:
x (torch.Tensor): Input data.
n_samples (int): Number of samples to generate.
Returns:
torch.Tensor: Prior predictive samples.
"""
# Get prior distribution
prior_dist = self.forward(x)

# Generate samples
samples = prior_dist.sample(torch.Size([n_samples]))

return samples


def fit(X_train, y_train, dims=None):
Expand Down Expand Up @@ -104,4 +125,4 @@ def predict(model, likelihood, data):
with torch.no_grad(), gpytorch.settings.fast_pred_var():
pred = likelihood(model(data))

return pred
return pred
166 changes: 110 additions & 56 deletions REStats/notebooks/power_curve.ipynb

Large diffs are not rendered by default.

46 changes: 42 additions & 4 deletions tests/models/test_power_curve.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import torch
import pytest
import gpytorch
from gpytorch.distributions import MultivariateNormal

from REStats.models.power_curve import ExactGPModel, fit, predict
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.distributions import MultivariateNormal
Expand All @@ -12,11 +15,46 @@ def dummy_data():
return train_x, train_y


def test_exact_gp_model(dummy_data):
train_x, train_y = dummy_data
likelihood = GaussianLikelihood()

def test_exact_gp_model_init():
train_x = torch.randn(10)
train_y = torch.randn(10)
likelihood = gpytorch.likelihoods.GaussianLikelihood()

model = ExactGPModel(train_x, train_y, likelihood)
assert isinstance(model, ExactGPModel)

assert isinstance(model.mean_module, gpytorch.means.ConstantMean)
assert isinstance(model.covar_module, gpytorch.kernels.ScaleKernel)
assert isinstance(model.covar_module.base_kernel, gpytorch.kernels.RBFKernel)


def test_exact_gp_model_forward():
train_x = torch.randn(10)
train_y = torch.randn(10)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)

# Set the model to evaluation mode
model.eval()

x = torch.randn(5)
output = model(x)

assert isinstance(output, MultivariateNormal)
assert output.mean.shape == (5,)
assert output.covariance_matrix.shape == (5, 5)


def test_prior_predictive_samples():
train_x = torch.randn(10)
train_y = torch.randn(10)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)

x = torch.linspace(-5, 5, 100)
samples = model.prior_predictive_samples(x, n_samples=5)

assert samples.shape == (5, 100)


def test_fit(dummy_data):
Expand Down

0 comments on commit 9ab8327

Please sign in to comment.