Skip to content

Commit

Permalink
faster d01kjc for squared exponential kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
ecignoni committed Apr 23, 2024
1 parent 5521ae1 commit 623236c
Showing 1 changed file with 50 additions and 4 deletions.
54 changes: 50 additions & 4 deletions gpx/kernels/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,52 @@ def squared_exponential_kernel_d01kj(
)


def _squared_exponential_kernel_d01kjc(
x1: ArrayLike,
x2: ArrayLike,
lengthscale: ArrayLike,
jaccoef: ArrayLike,
jacobian: ArrayLike,
active_dims: ArrayLike,
) -> Array:
ns1, nf = x1.shape
ns2, _ = x2.shape
_, _, nv2 = jacobian.shape
nact = active_dims.shape[0]
z1 = x1[:, active_dims] / lengthscale
z2 = x2[:, active_dims] / lengthscale
ed2 = jnp.exp(-squared_distances(z1, z2))
diff = (2.0 / lengthscale) * (z1[:, jnp.newaxis] - z2)
diff_j1 = jnp.einsum("stf,sf->st", diff, jaccoef[:, active_dims])
diff_j2 = jnp.einsum("stf,tfu->stu", diff, jacobian[:, active_dims])
d01kj = jnp.einsum("st,st,stu->stu", -ed2, diff_j1, diff_j2)
diag = ed2 * (2.0 / lengthscale**2)
diag = diag[:, :, jnp.newaxis].repeat(nact, axis=2)
diag = jnp.einsum(
"sf,stf,tfu->stu", jaccoef[:, active_dims], diag, jacobian[:, active_dims]
)
d01kj = d01kj + diag
d01kj = d01kj.reshape(ns1, ns2 * nv2)
return d01kj


@jit
def squared_exponential_kernel_d01kjc(
x1: ArrayLike,
x2: ArrayLike,
params: Dict[str, Parameter],
jaccoef: ArrayLike,
jacobian: ArrayLike,
active_dims: ArrayLike = None,
) -> Array:
lengthscale = params["lengthscale"].value
if active_dims is None:
active_dims = jnp.arange(x1.shape[1])
return _squared_exponential_kernel_d01kjc(
x1, x2, lengthscale, jaccoef, jacobian, active_dims
)


# =============================================================================
# Matern(1/2) Kernel
# =============================================================================
Expand Down Expand Up @@ -969,12 +1015,12 @@ def _matern52_kernel_d01kjc(
x2: ArrayLike,
lengthscale: ArrayLike,
jaccoef: ArrayLike,
jacobian2: ArrayLike,
jacobian: ArrayLike,
active_dims: ArrayLike,
) -> Array:
ns1, nf1 = x1.shape
ns2, _ = x2.shape
_, _, nv2 = jacobian2.shape
_, _, nv2 = jacobian.shape
nact = active_dims.shape[0]
z1 = x1[:, active_dims] / lengthscale
z2 = x2[:, active_dims] / lengthscale
Expand All @@ -983,12 +1029,12 @@ def _matern52_kernel_d01kjc(
d = jnp.sqrt(5.0) * jnp.sqrt(jnp.maximum(d2, 1e-36))
constant = (5.0 / (3.0 * lengthscale**2)) * jnp.exp(-d)
diff_j1 = jnp.einsum("stf,sf->st", diff, jaccoef[:, active_dims])
diff_j2 = jnp.einsum("stf,tfu->stu", diff, jacobian2[:, active_dims])
diff_j2 = jnp.einsum("stf,tfu->stu", diff, jacobian[:, active_dims])
d01kj = jnp.einsum("st,st,stu->stu", -constant, diff_j1, diff_j2)
diag = constant * (1.0 + d)
diag = diag[:, :, jnp.newaxis].repeat(nact, axis=2)
diag = jnp.einsum(
"sf,stf,tfu->stu", jaccoef[:, active_dims], diag, jacobian2[:, active_dims]
"sf,stf,tfu->stu", jaccoef[:, active_dims], diag, jacobian[:, active_dims]
)
d01kj = d01kj + diag
# output a square kernel, samples time variables
Expand Down

0 comments on commit 623236c

Please sign in to comment.