Skip to content

Commit

Permalink
Fixed bug product with jacobian
Browse files Browse the repository at this point in the history
  • Loading branch information
pamazzeo committed Jan 17, 2024
1 parent 3e7de17 commit 3234125
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions gpx/kernels/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,13 @@ def prod_kernels_deriv_jac(
def kernel(x1, x2, params, jacobian):
params1 = params["kernel1"]
params2 = params["kernel2"]
return kernel_func1(x1, x2, params1).repeat(3, axis=axis) * deriv_func2(
_, _, jv = jacobian.shape
return kernel_func1(x1, x2, params1).repeat(jv, axis=axis) * deriv_func2(
x1, x2, params2, jacobian
) + deriv_func1(x1, x2, params1, jacobian) * kernel_func2(
x1, x2, params2
).repeat(
3, axis=axis
jv, axis=axis
)

return kernel
Expand All @@ -206,15 +207,17 @@ def prod_kernels_deriv01_jac(
def kernel(x1, x2, params, jacobian1, jacobian2):
params1 = params["kernel1"]
params2 = params["kernel2"]
_, _, jv1 = jacobian1.shape
_, _, jv2 = jacobian2.shape
return (
kernel_func1(x1, x2, params1).repeat(3, axis=0).repeat(3, axis=-1)
kernel_func1(x1, x2, params1).repeat(jv1, axis=0).repeat(jv2, axis=-1)
* deriv01_func2(x1, x2, params2, jacobian1, jacobian2)
+ deriv0_func1(x1, x2, params1, jacobian1).repeat(3, axis=-1)
* deriv1_func2(x1, x2, params2, jacobian2).repeat(3, axis=0)
+ deriv1_func1(x1, x2, params1, jacobian2).repeat(3, axis=0)
* deriv0_func2(x1, x2, params2, jacobian1).repeat(3, axis=-1)
+ deriv0_func1(x1, x2, params1, jacobian1).repeat(jv2, axis=-1)
* deriv1_func2(x1, x2, params2, jacobian2).repeat(jv1, axis=0)
+ deriv1_func1(x1, x2, params1, jacobian2).repeat(jv1, axis=0)
* deriv0_func2(x1, x2, params2, jacobian1).repeat(jv2, axis=-1)
+ deriv01_func1(x1, x2, params1, jacobian1, jacobian2)
* kernel_func2(x1, x2, params2).repeat(3, axis=0).repeat(3, axis=-1)
* kernel_func2(x1, x2, params2).repeat(jv1, axis=0).repeat(jv2, axis=-1)
)

return kernel

0 comments on commit 3234125

Please sign in to comment.