Fix multiple issues with ivy.jac (ivy-llc#20973)
somaia02 authored Aug 11, 2023
1 parent 1fe8b5f commit 4a9a833
Showing 5 changed files with 70 additions and 11 deletions.
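For context, `ivy.jac` takes a callable and returns a callable that computes its Jacobian under the active backend. A minimal usage sketch (not part of the diff; the lambda and shapes are illustrative):

```python
import ivy

ivy.set_backend("jax")  # any backend with gradient support

x = ivy.array([1.0, 2.0, 3.0])
jac_fn = ivy.jac(lambda x: ivy.mean(ivy.square(x)))
jacobian = jac_fn(x)  # d(mean(x**2))/dx = 2*x/3, shape (3,)
```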
7 changes: 6 additions & 1 deletion ivy/functional/backends/jax/gradients.py
@@ -145,10 +145,15 @@ def stop_gradient(


 def jac(func: Callable):
-    grad_fn = lambda x_in: ivy.to_native(func(x_in), nested=True)
+    grad_fn = lambda x_in: ivy.to_native(
+        func(ivy.to_ivy(x_in, nested=True)),
+        nested=True,
+        include_derived=True,
+    )
     callback_fn = lambda x_in: ivy.to_ivy(
         jax.jacfwd(grad_fn)((ivy.to_native(x_in, nested=True))),
         nested=True,
+        include_derived=True,
     )
     return callback_fn

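Note on the jax change: native inputs are now converted back to ivy arrays before the user function is called, and `include_derived=True` makes `to_ivy`/`to_native` recurse into derived containers such as tuples and lists of arrays. A minimal sketch of what this is meant to enable, assuming the jax backend (the function and inputs are illustrative):

```python
import ivy

ivy.set_backend("jax")

# a function with a nested (tuple) input and output
func = lambda xs: (ivy.mean(ivy.square(xs[0])), ivy.mean(ivy.cos(xs[1])))
xs = (ivy.array([1.0, 2.0]), ivy.array([3.0, 4.0]))

# include_derived=True lets the conversions recurse into the tuples,
# so a Jacobian is produced for every (output, input) pair
jacobian = ivy.jac(func)(xs)
```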
24 changes: 21 additions & 3 deletions ivy/functional/backends/paddle/gradients.py
@@ -198,7 +198,7 @@ def _get_jac_one_arg_fn(grad_fn, xs, out_idx):

     def one_arg_fn(x):
         idx = next(nested_indices)
-        new_xs = ivy.set_nest_at_index(xs, idx, x) if idx else x
+        new_xs = ivy.set_nest_at_index(xs, idx, x, shallow=False) if idx else x
         ret = grad_fn(new_xs)
         for i in out_idx:
             ret = ret[i]
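The `shallow=False` fix matters because `one_arg_fn` runs once per input leaf: mutating `xs` in place would leak the substituted leaf into subsequent calls. A rough sketch of the assumed semantics (values are illustrative):

```python
import ivy

xs = [ivy.array([1.0]), ivy.array([2.0])]
# shallow=False is assumed to return a modified copy, leaving xs intact
new_xs = ivy.set_nest_at_index(xs, [0], ivy.array([9.0]), shallow=False)
assert ivy.to_scalar(xs[0]) == 1.0 and ivy.to_scalar(new_xs[0]) == 9.0
```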
@@ -212,19 +212,37 @@ def _get_one_out_fn(grad_fn, xs, fn_ret):

     def one_out_fn(o):
         out_idx = next(out_nested_indices)
+        out_shape = ivy.index_nest(grad_fn(xs), out_idx).shape
         one_arg_fn = _get_jac_one_arg_fn(grad_fn, xs, out_idx)
         jacobian = ivy.nested_map(
             xs,
-            lambda x: paddle.incubate.autograd.Jacobian(one_arg_fn, ivy.to_native(x)),
+            lambda x: jacobian_to_ivy(
+                paddle.incubate.autograd.Jacobian(
+                    one_arg_fn, ivy.to_native(x.expand_dims())
+                ),
+                x.shape,
+                out_shape,
+            ),
             shallow=False,
         )
         return jacobian
 
     return one_out_fn
 
 
+def jacobian_to_ivy(jacobian, in_shape, out_shape):
+    jac_ivy = ivy.to_ivy(jacobian[:])
+    jac_shape = out_shape + in_shape
+    jac_reshaped = jac_ivy.reshape(jac_shape)
+    return jac_reshaped
+
+
 def jac(func: Callable):
-    grad_fn = lambda x_in: ivy.to_native(func(x_in), nested=True)
+    grad_fn = lambda x_in: ivy.to_native(
+        func(ivy.to_ivy(x_in, nested=True)),
+        nested=True,
+        include_derived=True,
+    )
 
     def callback_fn(xs):
         fn_ret = grad_fn(xs)
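Why the new `jacobian_to_ivy` helper: `paddle.incubate.autograd.Jacobian` exposes its result as a flattened matrix when sliced with `jacobian[:]`, so the helper reshapes it to the `out_shape + in_shape` layout that `jax.jacfwd` and `torch.func.jacfwd` already produce. A worked shape example (the flat matrix below stands in for what `jacobian[:]` is assumed to yield):

```python
import numpy as np

in_shape, out_shape = (2, 3), (2, 3)  # elementwise func: shapes match
flat = np.zeros((np.prod(out_shape), np.prod(in_shape)))  # flattened 6 x 6 matrix
full = flat.reshape(out_shape + in_shape)  # -> (2, 3, 2, 3), the jacfwd convention
```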
8 changes: 7 additions & 1 deletion ivy/functional/backends/tensorflow/gradients.py
@@ -162,10 +162,15 @@ def stop_gradient(


 def jac(func: Callable):
-    grad_fn = lambda x_in: ivy.to_native(func(x_in), nested=True)
+    grad_fn = lambda x_in: ivy.to_native(
+        func(ivy.to_ivy(x_in, nested=True)),
+        nested=True,
+        include_derived=True,
+    )
 
     def callback_fn(x_in):
         with tf.GradientTape(persistent=True) as tape:
+            ivy.nested_map(x_in, ivy.copy_array)
             x_in = ivy.to_native(x_in, nested=True)
             tape.watch(x_in)
             y = grad_fn(x_in)
@@ -178,6 +183,7 @@ def callback_fn(x_in):
                     tape.jacobian(yi, x_in, unconnected_gradients="zero"),
                     nested=True,
                 ),
+                include_derived=True,
             )
         else:
             jacobian = ivy.to_ivy(tape.jacobian(y, x_in))
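For reference, a standalone sketch of the `tf.GradientTape` pattern this backend wraps; the `include_derived=True` additions only change how nested results are converted back to ivy types:

```python
import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0])
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)  # plain tensors must be watched explicitly
    y = tf.cos(x)
jacobian = tape.jacobian(y, x, unconnected_gradients="zero")  # shape (3, 3)
```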
7 changes: 6 additions & 1 deletion ivy/functional/backends/torch/gradients.py
@@ -182,10 +182,15 @@ def stop_gradient(


 def jac(func: Callable):
-    grad_fn = lambda x_in: ivy.to_native(func(x_in), nested=True)
+    grad_fn = lambda x_in: ivy.to_native(
+        func(ivy.to_ivy(x_in, nested=True)),
+        nested=True,
+        include_derived=True,
+    )
     callback_fn = lambda x_in: ivy.to_ivy(
         torch.func.jacfwd(grad_fn)((ivy.to_native(x_in, nested=True))),
         nested=True,
+        include_derived=True,
     )
     return callback_fn

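For comparison, the forward-mode primitive the torch backend delegates to (requires torch >= 2.0, where `torch.func` is available):

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])
jacobian = torch.func.jacfwd(torch.cos)(x)  # (3, 3), diagonal = -sin(x)
```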
35 changes: 30 additions & 5 deletions ivy_tests/test_ivy/test_functional/test_core/test_gradients.py
@@ -187,23 +187,26 @@ def test_value_and_grad(x, dtype, func, backend_fw):
"x", [[[4.6, 2.1, 5], [2.8, 1.3, 6.2]], [[4.6, 2.1], [5, 2.8], [1.3, 6.2]]]
)
@pytest.mark.parametrize("dtype", ["float32", "float64"])
@pytest.mark.parametrize(
"func", [lambda x: ivy.mean(ivy.square(x)), lambda x: ivy.mean(ivy.cos(x))]
)
def test_jac(x, dtype, func, backend_fw):
@pytest.mark.parametrize("func_str", ["square", "cos"])
def test_jac(x, dtype, func_str, backend_fw):
if backend_fw == "numpy":
pytest.skip()

with update_backend(backend_fw) as ivy_backend:
f = ivy_backend.__dict__[func_str]
func = lambda x: ivy_backend.mean(f(x))
_variable_fn = ivy_backend.ivy.functional.ivy.gradients._variable
var = _variable_fn(ivy_backend.array(x, dtype=dtype))
fn = ivy_backend.jac(func)
jacobian = fn(var)
jacobian_np = helpers.flatten_and_to_np(ret=jacobian, backend=backend_fw)
assert jacobian_np != []

with update_backend("tensorflow") as gt_backend:
f = gt_backend.__dict__[func_str]
func = lambda x: gt_backend.mean(f(x))
_variable_fn = gt_backend.ivy.functional.ivy.gradients._variable
var = _variable_fn(ivy.array(x, dtype=dtype))
var = _variable_fn(gt_backend.array(x, dtype=dtype))
fn = gt_backend.jac(func)
jacobian_gt = fn(var)
jacobian_np_from_gt = helpers.flatten_and_to_np(
Expand Down Expand Up @@ -240,6 +243,28 @@ def test_jac(x, dtype, func, backend_fw):
         assert jacobian.shape == jacobian_from_gt.shape
         assert np.allclose(jacobian, jacobian_from_gt)
 
+    # Test func with non 0-dim output
+    with update_backend(backend_fw) as ivy_backend:
+        func = ivy_backend.__dict__[func_str]
+        _variable_fn = ivy_backend.ivy.functional.ivy.gradients._variable
+        var = _variable_fn(ivy_backend.array(x, dtype=dtype))
+        fn = ivy_backend.jac(func)
+        jacobian = fn(var)
+        jacobian_np = helpers.flatten_and_to_np(ret=jacobian, backend=backend_fw)
+
+    with update_backend("tensorflow") as gt_backend:
+        func = gt_backend.__dict__[func_str]
+        _variable_fn = gt_backend.ivy.functional.ivy.gradients._variable
+        var = _variable_fn(gt_backend.array(x, dtype=dtype))
+        fn = gt_backend.jac(func)
+        jacobian_gt = fn(var)
+        jacobian_np_from_gt = helpers.flatten_and_to_np(
+            ret=jacobian_gt, backend="tensorflow"
+        )
+    for jacobian, jacobian_from_gt in zip(jacobian_np, jacobian_np_from_gt):
+        assert jacobian.shape == jacobian_from_gt.shape
+        assert np.allclose(jacobian, jacobian_from_gt)
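The added block repeats the backend-vs-TensorFlow comparison for `func` with a non-scalar output, where the Jacobian of an elementwise function should follow the `out_shape + in_shape` convention established above. A hedged sketch of the expectation (backend choice and shapes illustrative):

```python
import ivy

ivy.set_backend("torch")  # any backend other than numpy
x = ivy.array([[4.6, 2.1, 5.0], [2.8, 1.3, 6.2]])
jacobian = ivy.jac(ivy.cos)(x)
# elementwise func: out_shape == in_shape == (2, 3)
assert tuple(jacobian.shape) == (2, 3, 2, 3)
```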


# grad
@pytest.mark.parametrize(
