diff --git a/gpx/kernels/kernels.py b/gpx/kernels/kernels.py index 02136df..7a3be57 100644 --- a/gpx/kernels/kernels.py +++ b/gpx/kernels/kernels.py @@ -558,6 +558,52 @@ def squared_exponential_kernel_d01kj( ) +def _squared_exponential_kernel_d01kjc( + x1: ArrayLike, + x2: ArrayLike, + lengthscale: ArrayLike, + jaccoef: ArrayLike, + jacobian: ArrayLike, + active_dims: ArrayLike, +) -> Array: + ns1, nf = x1.shape + ns2, _ = x2.shape + _, _, nv2 = jacobian.shape + nact = active_dims.shape[0] + z1 = x1[:, active_dims] / lengthscale + z2 = x2[:, active_dims] / lengthscale + ed2 = jnp.exp(-squared_distances(z1, z2)) + diff = (2.0 / lengthscale) * (z1[:, jnp.newaxis] - z2) + diff_j1 = jnp.einsum("stf,sf->st", diff, jaccoef[:, active_dims]) + diff_j2 = jnp.einsum("stf,tfu->stu", diff, jacobian[:, active_dims]) + d01kj = jnp.einsum("st,st,stu->stu", -ed2, diff_j1, diff_j2) + diag = ed2 * (2.0 / lengthscale**2) + diag = diag[:, :, jnp.newaxis].repeat(nact, axis=2) + diag = jnp.einsum( + "sf,stf,tfu->stu", jaccoef[:, active_dims], diag, jacobian[:, active_dims] + ) + d01kj = d01kj + diag + d01kj = d01kj.reshape(ns1, ns2 * nv2) + return d01kj + + +@jit +def squared_exponential_kernel_d01kjc( + x1: ArrayLike, + x2: ArrayLike, + params: Dict[str, Parameter], + jaccoef: ArrayLike, + jacobian: ArrayLike, + active_dims: ArrayLike = None, +) -> Array: + lengthscale = params["lengthscale"].value + if active_dims is None: + active_dims = jnp.arange(x1.shape[1]) + return _squared_exponential_kernel_d01kjc( + x1, x2, lengthscale, jaccoef, jacobian, active_dims + ) + + # ============================================================================= # Matern(1/2) Kernel # ============================================================================= @@ -969,12 +1015,12 @@ def _matern52_kernel_d01kjc( x2: ArrayLike, lengthscale: ArrayLike, jaccoef: ArrayLike, - jacobian2: ArrayLike, + jacobian: ArrayLike, active_dims: ArrayLike, ) -> Array: ns1, nf1 = x1.shape ns2, _ = x2.shape - _, _, nv2 = jacobian2.shape + _, _, nv2 = jacobian.shape nact = active_dims.shape[0] z1 = x1[:, active_dims] / lengthscale z2 = x2[:, active_dims] / lengthscale @@ -983,12 +1029,12 @@ def _matern52_kernel_d01kjc( d = jnp.sqrt(5.0) * jnp.sqrt(jnp.maximum(d2, 1e-36)) constant = (5.0 / (3.0 * lengthscale**2)) * jnp.exp(-d) diff_j1 = jnp.einsum("stf,sf->st", diff, jaccoef[:, active_dims]) - diff_j2 = jnp.einsum("stf,tfu->stu", diff, jacobian2[:, active_dims]) + diff_j2 = jnp.einsum("stf,tfu->stu", diff, jacobian[:, active_dims]) d01kj = jnp.einsum("st,st,stu->stu", -constant, diff_j1, diff_j2) diag = constant * (1.0 + d) diag = diag[:, :, jnp.newaxis].repeat(nact, axis=2) diag = jnp.einsum( - "sf,stf,tfu->stu", jaccoef[:, active_dims], diag, jacobian2[:, active_dims] + "sf,stf,tfu->stu", jaccoef[:, active_dims], diag, jacobian[:, active_dims] ) d01kj = d01kj + diag # output a square kernel, samples time variables