Skip to content

Commit

Permalink
each kernel class now supports d01kjc. Prod and Sum kernels are still…
Browse files Browse the repository at this point in the history
… problematic in the sense that they neede specialized functions.
  • Loading branch information
ecignoni committed Apr 23, 2024
1 parent 623236c commit ae56da6
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions gpx/kernels/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,6 +1179,21 @@ def __init__(self, active_dims: ArrayLike = None) -> None:
active_dims=active_dims,
)

# functions accepting a 0-jacobian already contracted
# with the regression coefficients
self.d0kjc = partial(
grad_kernelize(
argnums=0, with_jacob=True, with_jaccoef=True, trace_samples=False
)(self._kernel_base),
active_dims=active_dims,
)
self.d01kjc = partial(
grad_kernelize(
argnums=(0, 1), with_jacob=True, with_jaccoef=True, trace_samples=False
)(self._kernel_base),
active_dims=active_dims,
)

def __call__(self, x1: ArrayLike, x2: ArrayLike, params: Dict) -> Array:
return self.k(x1, x2, params)

Expand Down Expand Up @@ -1324,6 +1339,11 @@ def __init__(self, active_dims: ArrayLike = None) -> None:
# faster hessian-jacobian
self.d01kj = partial(squared_exponential_kernel_d01kj, active_dims=active_dims)

# faster hessian-jaccoef
self.d01kjc = partial(
squared_exponential_kernel_d01kjc, active_dims=active_dims
)

def default_params(self):
return dict(
lengthscale=Parameter(
Expand Down Expand Up @@ -1417,6 +1437,9 @@ def __init__(self, active_dims: ArrayLike = None) -> None:
self.d1kj = partial(matern52_kernel_d1kj, active_dims=active_dims)
self.d01kj = partial(matern52_kernel_d01kj, active_dims=active_dims)

# faster hessian-jaccoef
self.d01kjc = partial(matern52_kernel_d01kjc, active_dims=active_dims)

def default_params(self):
return dict(
lengthscale=Parameter(
Expand Down Expand Up @@ -1510,6 +1533,9 @@ def __init__(self, kernel1: Kernel, kernel2: Kernel) -> None:
self.d1kj = sum_kernels_jac(kernel1.d1kj, kernel2.d1kj)
self.d01kj = sum_kernels_jac2(kernel1.d01kj, kernel2.d01kj)

# TODO: we need a dedicated function for the d0kjc and d01kjc
# until that moment, it will use the autodifferentiated function.

def default_params(self):
# simply delegate
return {
Expand Down Expand Up @@ -1586,6 +1612,9 @@ def __init__(self, kernel1: Kernel, kernel2: Kernel) -> None:
kernel2.d01kj,
)

# TODO: we need a dedicated function for the d0kjc and d01kjc
# until that moment, it will use the autodifferentiated function.

def default_params(self):
# simply delegate
return {
Expand Down

0 comments on commit ae56da6

Please sign in to comment.