Skip to content

Commit

Permalink
metric test
Browse files Browse the repository at this point in the history
  • Loading branch information
Dong555 committed Nov 2, 2022
1 parent a34e16f commit 4814441
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
6 changes: 5 additions & 1 deletion susiepca/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ def mse(X: jnp.ndarray, Xhat: jnp.ndarray) -> float:
RRMSE: relative root mean square error
"""
return jnp.sum((X - Xhat) ** 2) / jnp.sum(X ** 2)
if X.shape != Xhat.shape:
raise ValueError("Predicted data shape doesn't match, please check")

mse = jnp.sum((X - Xhat) ** 2) / jnp.sum(X ** 2)
return mse


def get_credset(params, rho=0.9):
Expand Down
19 changes: 12 additions & 7 deletions tests/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,22 @@


# define the test for MSE
@pytest.mark.parametrize("seed", 0)
def test_mse(seed):
np.random.seed(seed)
X = np.random.normal(0, 1, size=(3, 5))
Xhat = X + np.random.normal(0, 0.5, (3, 5))
expected_res = 0.11514
def test_mse():
X = np.array([[-0.58, -0.43, 0.70], [-0.50, -1.22, 0.91]])

Xhat = np.array([[-0.90, -0.50, 0.40], [-1.17, -1.47, 0.91]])
Xhat_wrongshape = np.array(
[[-0.90, -0.50, 0.40], [-1.17, -1.47, 0.91], [-0.32, 0.52, 1.36]]
)

expected_res = 0.1981
actual_res = sp.metrics.mse(X, Xhat)

assert X.shape == Xhat.shape
assert pytest.approx(expected_res) == actual_res

# with pytest.raises(Exception):
with pytest.raises(ValueError):
sp.metrics.mse(X, Xhat_wrongshape)


# define the test for credible set
Expand Down

0 comments on commit 4814441

Please sign in to comment.