Skip to content

Commit

Permalink
fixes to get some meta tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
mattbarrett98 committed Jun 15, 2022
1 parent de19394 commit 6ffd32f
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 5 deletions.
4 changes: 2 additions & 2 deletions ivy/container/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def prod(
) -> ivy.Container:
return self.handle_inplace(
self.map(
lambda x_, _: ivy.prod(x_, axis, keepdims, dtype=dtype)
lambda x_, _: ivy.prod(x_, axis=axis, keepdims=keepdims, dtype=dtype)
if ivy.is_array(x_)
else x_,
key_chains,
Expand All @@ -134,7 +134,7 @@ def sum(
) -> ivy.Container:
return self.handle_inplace(
self.map(
lambda x_, _: ivy.sum(x_, axis, dtype, keepdims)
lambda x_, _: ivy.sum(x_, axis=axis, dtype=dtype, keepdims=keepdims)
if ivy.is_array(x_)
else x_,
key_chains,
Expand Down
2 changes: 2 additions & 0 deletions ivy/functional/backends/tensorflow/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def add(
promoted_type = tf.experimental.numpy.promote_types(x1.dtype, x2.dtype)
x1 = tf.cast(x1, promoted_type)
x2 = tf.cast(x2, promoted_type)
elif not isinstance(x1, tf.Tensor):
x1 = tf.constant(x1, dtype=x2.dtype)
return tf.add(x1, x2)


Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/tensorflow/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


def is_native_array(x, exclusive=False):
if isinstance(x, tf.Tensor):
if isinstance(x, tf.Tensor) or isinstance(x, tf.Variable):
if exclusive and isinstance(x, tf.Variable):
return False
return True
Expand Down
2 changes: 2 additions & 0 deletions ivy/functional/backends/torch/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,8 @@ def subtract(
x1 = x1.to(promoted_type)
x2 = x2.to(promoted_type)
return torch.subtract(x1, x2, out=out)
elif not isinstance(x1, torch.Tensor):
x1 = torch.tensor(x1, dtype=x2.dtype)
return torch.subtract(x1, x2)


Expand Down
4 changes: 2 additions & 2 deletions ivy/functional/ivy/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _compute_cost_and_update_grads(
retain_grads=False,
)
if batched:
inner_grads = inner_grads * num_tasks
inner_grads = ivy.multiply(inner_grads, num_tasks)
if average_across_steps_or_final:
all_grads.append(inner_grads)
else:
Expand Down Expand Up @@ -85,7 +85,7 @@ def _train_task(
retain_grads=order > 1,
)
if batched:
inner_update_grads = inner_update_grads * num_tasks
inner_update_grads = ivy.multiply(inner_update_grads, num_tasks)

# compute the cost to be optimized, and update all_grads if fist order method
if outer_cost_fn is None and not unique_inner and not unique_outer:
Expand Down

0 comments on commit 6ffd32f

Please sign in to comment.