Skip to content

Commit

Permalink
Fix tests that used numpy backend instead of tested backend
Browse files Browse the repository at this point in the history
  • Loading branch information
fferflo committed Mar 8, 2024
1 parent f464acd commit 422c09a
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions test/test_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,14 +204,12 @@ def test_shape_reduce(backend):

assert einx.logsumexp("a [...]", x).shape == (16,)

assert einx.logsumexp("[a]", [0.0, 1.0]).shape == ()
assert einx.logsumexp("[a]", [np.asarray(0.0), np.asarray(1.0)]).shape == ()
assert einx.mean("[a]", [backend.to_tensor(0.0), np.asarray(1.0)]).shape == ()
assert einx.sum("[a]", [backend.to_tensor(0.0), backend.to_tensor(1.0)]).shape == ()
assert einx.logsumexp("[a] 1", [[0.0], [1.0]]).shape == (1,)
assert einx.logsumexp("[a]", [0.0] * 10).shape == ()
assert einx.logsumexp("[a]", [0.0, 1.0], backend=backend).shape == ()
assert einx.sum("[a]", [backend.to_tensor(0.0), backend.to_tensor(1.0)], backend=backend).shape == ()
assert einx.logsumexp("[a] 1", [[0.0], [1.0]], backend=backend).shape == (1,)
assert einx.logsumexp("[a]", [0.0] * 10, backend=backend).shape == ()
with pytest.raises(ValueError):
einx.logsumexp("a", [0.0, [1.0]])
einx.logsumexp("a", [0.0, [1.0]], backend=backend)

@pytest.mark.parametrize("backend", backends)
def test_shape_elementwise(backend):
Expand Down

0 comments on commit 422c09a

Please sign in to comment.