Skip to content

Commit

Permalink
added default arguments for the iterative calculation of the log marg…
Browse files Browse the repository at this point in the history
…inal likelihood + now iterative automatically calls a loss function with the keyword arugment iterative=True
  • Loading branch information
ecignoni committed Dec 15, 2023
1 parent c8b1a5a commit 1850509
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 89 deletions.
18 changes: 18 additions & 0 deletions gpx/defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from collections import namedtuple

import jax

GPX_DEFAULTS = {
# number of steps in stochastic trace estimation
"num_evals": 10,
# number of lanczos evaluations
"num_lanczos": 8,
# default "random" key for lanczos
"lanczos_key": jax.random.PRNGKey(2023),
}

gpxargs = namedtuple(
"GPX_DEFAULTS_ARGUMENTS",
GPX_DEFAULTS.keys(),
defaults=GPX_DEFAULTS.values(),
)()
48 changes: 31 additions & 17 deletions gpx/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from jax.typing import ArrayLike
from typing_extensions import Self

from ..defaults import gpxargs
from ..optimizers import NLoptWrapper, scipy_minimize, scipy_minimize_derivs
from ..parameters import ModelState, Parameter
from .utils import (
Expand Down Expand Up @@ -196,10 +197,10 @@ def log_marginal_likelihood(
x: ArrayLike,
y: ArrayLike,
return_negative: Optional[bool] = False,
iterative=False,
num_evals=None,
num_lanczos=None,
lanczos_key=None,
iterative: Optional[bool] = False,
num_evals: Optional[int] = gpxargs.num_evals,
num_lanczos: Optional[int] = gpxargs.num_lanczos,
lanczos_key: Optional[ArrayLike] = gpxargs.lanczos_key,
) -> Array:
"""Computes the log marginal likelihood.
Expand Down Expand Up @@ -235,10 +236,10 @@ def log_marginal_likelihood_derivs(
y: ArrayLike,
jacobian: ArrayLike,
return_negative: Optional[bool] = False,
iterative=False,
num_evals=None,
num_lanczos=None,
lanczos_key=None,
iterative: Optional[bool] = False,
num_evals: Optional[int] = gpxargs.num_evals,
num_lanczos: Optional[int] = gpxargs.num_lanczos,
lanczos_key: Optional[ArrayLike] = gpxargs.lanczos_key,
) -> Array:
"""Computes the log marginal likelihood using the hessian kernel.
Expand Down Expand Up @@ -287,10 +288,10 @@ def fit(
y: ArrayLike,
minimize: Optional[bool] = True,
num_restarts: Optional[int] = 0,
key: prng.PRNGKeyArray = None,
key: Optional[ArrayLike] = None,
return_history: Optional[bool] = False,
iterative: Optional[bool] = False,
loss_kwargs=None,
loss_kwargs: Optional[Dict] = None,
) -> Self:
"""fits the model
Expand Down Expand Up @@ -329,6 +330,11 @@ def fit(
In order to optimize with randomized restarts you need to provide a valid
JAX PRNGKey.
"""
# we tell the loss that it should be iterative
# note that this overrides an eventual 'iterative' keyword
if loss_kwargs is None:
loss_kwargs = {}
loss_kwargs["iterative"] = iterative
loss_fn = loss_fn_with_args(self.state.loss_fn, loss_kwargs)

if minimize:
Expand Down Expand Up @@ -369,10 +375,10 @@ def fit_derivs(
jacobian: ArrayLike,
minimize: Optional[bool] = True,
num_restarts: Optional[int] = 0,
key: prng.PRNGKeyArray = None,
key: Optional[ArrayLike] = None,
return_history: Optional[bool] = False,
iterative: Optional[bool] = False,
loss_kwargs=None,
loss_kwargs: Optional[Dict] = None,
) -> Self:
"""fits the model
Expand Down Expand Up @@ -415,6 +421,10 @@ def fit_derivs(
In order to optimize with randomized restarts you need to provide a valid
JAX PRNGKey.
"""
# we tell the loss that it should be iterative
if loss_kwargs is None:
loss_kwargs = {}
loss_kwargs["iterative"] = iterative
loss_fn = loss_fn_with_args(self.state.loss_fn, loss_kwargs)

if minimize:
Expand Down Expand Up @@ -460,13 +470,17 @@ def fit_nlopt(
x: ArrayLike,
y: ArrayLike,
opt: NLoptWrapper,
minimize=True,
key=None,
num_restarts=0,
return_history=False,
minimize: Optional[bool] = True,
key: Optional[ArrayLike] = None,
num_restarts: Optional[int] = 0,
return_history: Optional[bool] = False,
iterative: Optional[bool] = False,
loss_kwargs=None,
loss_kwargs: Optional[Dict] = None,
) -> Self:
# we tell the loss that it should be iterative
if loss_kwargs is None:
loss_kwargs = {}
loss_kwargs["iterative"] = iterative
loss_fn = loss_fn_with_args(self.state.loss_fn, loss_kwargs)

if minimize:
Expand Down
49 changes: 25 additions & 24 deletions gpx/models/gpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from jax.typing import ArrayLike

from ..bijectors import Softplus
from ..defaults import gpxargs
from ..mean_functions import data_mean, zero_mean
from ..parameters import ModelState
from ..parameters.parameter import Parameter, is_parameter
Expand Down Expand Up @@ -49,9 +50,9 @@ def log_marginal_likelihood(
x: ArrayLike,
y: ArrayLike,
iterative: Optional[bool] = False,
num_evals: Optional[int] = None,
num_lanczos: Optional[int] = None,
lanczos_key: Optional[prng.PRNGKeyArray] = None,
num_evals: Optional[int] = gpxargs.num_evals,
num_lanczos: Optional[int] = gpxargs.num_lanczos,
lanczos_key: Optional[prng.PRNGKeyArray] = gpxargs.lanczos_key,
) -> Array:
"""computes the log marginal likelihood for standard gaussian process
Expand Down Expand Up @@ -97,9 +98,9 @@ def log_marginal_likelihood_derivs(
y: ArrayLike,
jacobian: ArrayLike,
iterative: Optional[bool] = False,
num_evals: Optional[int] = None,
num_lanczos: Optional[int] = None,
lanczos_key: Optional[prng.PRNGKeyArray] = None,
num_evals: Optional[int] = gpxargs.num_evals,
num_lanczos: Optional[int] = gpxargs.num_lanczos,
lanczos_key: Optional[prng.PRNGKeyArray] = gpxargs.lanczos_key,
) -> Array:
"""computes the log marginal likelihood for standard gaussian process
using the Hessian kernel
Expand Down Expand Up @@ -158,9 +159,9 @@ def log_posterior(
x: ArrayLike,
y: ArrayLike,
iterative: Optional[bool] = False,
num_evals: Optional[int] = None,
num_lanczos: Optional[int] = None,
lanczos_key: Optional[prng.PRNGKeyArray] = None,
num_evals: Optional[int] = gpxargs.num_evals,
num_lanczos: Optional[int] = gpxargs.num_lanczos,
lanczos_key: Optional[prng.PRNGKeyArray] = gpxargs.lanczos_key,
) -> Array:
"""Computes the log posterior
Expand All @@ -186,9 +187,9 @@ def log_posterior_derivs(
y: ArrayLike,
jacobian: ArrayLike,
iterative: Optional[bool] = False,
num_evals: Optional[int] = None,
num_lanczos: Optional[int] = None,
lanczos_key: Optional[prng.PRNGKeyArray] = None,
num_evals: Optional[int] = gpxargs.num_evals,
num_lanczos: Optional[int] = gpxargs.num_lanczos,
lanczos_key: Optional[prng.PRNGKeyArray] = gpxargs.lanczos_key,
) -> Array:
"""Computes the log posterior
Expand All @@ -214,9 +215,9 @@ def neg_log_marginal_likelihood(
x: ArrayLike,
y: ArrayLike,
iterative: Optional[bool] = False,
num_evals: Optional[int] = None,
num_lanczos: Optional[int] = None,
lanczos_key: Optional[prng.PRNGKeyArray] = None,
num_evals: Optional[int] = gpxargs.num_evals,
num_lanczos: Optional[int] = gpxargs.num_lanczos,
lanczos_key: Optional[prng.PRNGKeyArray] = gpxargs.lanczos_key,
) -> Array:
"Returns the negative log marginal likelihood"
return -log_marginal_likelihood(
Expand All @@ -236,9 +237,9 @@ def neg_log_marginal_likelihood_derivs(
y: ArrayLike,
jacobian: ArrayLike,
iterative: Optional[bool] = False,
num_evals: Optional[int] = None,
num_lanczos: Optional[int] = None,
lanczos_key: Optional[prng.PRNGKeyArray] = None,
num_evals: Optional[int] = gpxargs.num_evals,
num_lanczos: Optional[int] = gpxargs.num_lanczos,
lanczos_key: Optional[prng.PRNGKeyArray] = gpxargs.lanczos_key,
) -> Array:
"Returns the negative log marginal likelihood"
return -log_marginal_likelihood_derivs(
Expand All @@ -258,9 +259,9 @@ def neg_log_posterior(
x: ArrayLike,
y: ArrayLike,
iterative: Optional[bool] = False,
num_evals: Optional[int] = None,
num_lanczos: Optional[int] = None,
lanczos_key: Optional[prng.PRNGKeyArray] = None,
num_evals: Optional[int] = gpxargs.num_evals,
num_lanczos: Optional[int] = gpxargs.num_lanczos,
lanczos_key: Optional[prng.PRNGKeyArray] = gpxargs.lanczos_key,
) -> Array:
"Returns the negative log posterior"
return -log_posterior(
Expand All @@ -280,9 +281,9 @@ def neg_log_posterior_derivs(
y: ArrayLike,
jacobian: ArrayLike,
iterative: Optional[bool] = False,
num_evals: Optional[int] = None,
num_lanczos: Optional[int] = None,
lanczos_key: Optional[prng.PRNGKeyArray] = None,
num_evals: Optional[int] = gpxargs.num_evals,
num_lanczos: Optional[int] = gpxargs.num_lanczos,
lanczos_key: Optional[prng.PRNGKeyArray] = gpxargs.lanczos_key,
) -> Array:
"Returns the negative log posterior"
return -log_posterior_derivs(
Expand Down
49 changes: 25 additions & 24 deletions gpx/models/sgpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from jax.typing import ArrayLike

from ..bijectors import Softplus
from ..defaults import gpxargs
from ..mean_functions import data_mean, zero_mean
from ..parameters.model_state import ModelState
from ..parameters.parameter import Parameter, is_parameter
Expand Down Expand Up @@ -45,9 +46,9 @@ def log_marginal_likelihood(
x: ArrayLike,
y: ArrayLike,
iterative: Optional[bool] = False,
num_evals: Optional[int] = None,
num_lanczos: Optional[int] = None,
lanczos_key: Optional[prng.PRNGKeyArray] = None,
num_evals: Optional[int] = gpxargs.num_evals,
num_lanczos: Optional[int] = gpxargs.num_lanczos,
lanczos_key: Optional[prng.PRNGKeyArray] = gpxargs.lanczos_key,
) -> Array:
"""log marginal likelihood for SGPR (projected processes)
Expand Down Expand Up @@ -90,9 +91,9 @@ def log_marginal_likelihood_derivs(
y: ArrayLike,
jacobian: ArrayLike,
iterative: Optional[bool] = False,
num_evals: Optional[int] = None,
num_lanczos: Optional[int] = None,
lanczos_key: Optional[prng.PRNGKeyArray] = None,
num_evals: Optional[int] = gpxargs.num_evals,
num_lanczos: Optional[int] = gpxargs.num_lanczos,
lanczos_key: Optional[prng.PRNGKeyArray] = gpxargs.lanczos_key,
) -> Array:
"""log marginal likelihood for SGPR (projected processes)
Expand Down Expand Up @@ -149,9 +150,9 @@ def log_posterior(
x: ArrayLike,
y: ArrayLike,
iterative: Optional[bool] = False,
num_evals: Optional[int] = None,
num_lanczos: Optional[int] = None,
lanczos_key: Optional[prng.PRNGKeyArray] = None,
num_evals: Optional[int] = gpxargs.num_evals,
num_lanczos: Optional[int] = gpxargs.num_lanczos,
lanczos_key: Optional[prng.PRNGKeyArray] = gpxargs.lanczos_key,
) -> Array:
"""Computes the log posterior
Expand All @@ -176,9 +177,9 @@ def log_posterior_derivs(
y: ArrayLike,
jacobian: ArrayLike,
iterative: Optional[bool] = False,
num_evals: Optional[int] = None,
num_lanczos: Optional[int] = None,
lanczos_key: Optional[prng.PRNGKeyArray] = None,
num_evals: Optional[int] = gpxargs.num_evals,
num_lanczos: Optional[int] = gpxargs.num_lanczos,
lanczos_key: Optional[prng.PRNGKeyArray] = gpxargs.lanczos_key,
) -> Array:
"""Computes the log posterior
Expand All @@ -204,9 +205,9 @@ def neg_log_marginal_likelihood(
x: ArrayLike,
y: ArrayLike,
iterative: Optional[bool] = False,
num_evals: Optional[int] = None,
num_lanczos: Optional[int] = None,
lanczos_key: Optional[prng.PRNGKeyArray] = None,
num_evals: Optional[int] = gpxargs.num_evals,
num_lanczos: Optional[int] = gpxargs.num_lanczos,
lanczos_key: Optional[prng.PRNGKeyArray] = gpxargs.lanczos_key,
) -> Array:
"Returns the negative log marginal likelihood"
return -log_marginal_likelihood(
Expand All @@ -226,9 +227,9 @@ def neg_log_marginal_likelihood_derivs(
y: ArrayLike,
jacobian: ArrayLike,
iterative: Optional[bool] = False,
num_evals: Optional[int] = None,
num_lanczos: Optional[int] = None,
lanczos_key: Optional[prng.PRNGKeyArray] = None,
num_evals: Optional[int] = gpxargs.num_evals,
num_lanczos: Optional[int] = gpxargs.num_lanczos,
lanczos_key: Optional[prng.PRNGKeyArray] = gpxargs.lanczos_key,
) -> Array:
"Returns the negative log marginal likelihood"
return -log_marginal_likelihood_derivs(
Expand All @@ -248,9 +249,9 @@ def neg_log_posterior(
x: ArrayLike,
y: ArrayLike,
iterative: Optional[bool] = False,
num_evals: Optional[int] = None,
num_lanczos: Optional[int] = None,
lanczos_key: Optional[prng.PRNGKeyArray] = None,
num_evals: Optional[int] = gpxargs.num_evals,
num_lanczos: Optional[int] = gpxargs.num_lanczos,
lanczos_key: Optional[prng.PRNGKeyArray] = gpxargs.lanczos_key,
) -> Array:
"Returns the negative log posterior"
return -log_posterior(
Expand All @@ -270,9 +271,9 @@ def neg_log_posterior_derivs(
y: ArrayLike,
jacobian: ArrayLike,
iterative: Optional[bool] = False,
num_evals: Optional[int] = None,
num_lanczos: Optional[int] = None,
lanczos_key: Optional[prng.PRNGKeyArray] = None,
num_evals: Optional[int] = gpxargs.num_evals,
num_lanczos: Optional[int] = gpxargs.num_lanczos,
lanczos_key: Optional[prng.PRNGKeyArray] = gpxargs.lanczos_key,
) -> Array:
"Returns the negative log posterior"
return -log_posterior_derivs(
Expand Down
Loading

0 comments on commit 1850509

Please sign in to comment.