Skip to content

Commit

Permalink
Enable custom metrics in evaluate.py (blei-lab#809)
Browse files Browse the repository at this point in the history
  • Loading branch information
siddharth-agrawal authored and dustinvtran committed Dec 30, 2017
1 parent 4d8c1f3 commit a58e159
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
6 changes: 5 additions & 1 deletion edward/criticisms/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,10 @@ def evaluate(metrics, data, n_samples=500, output_key=None, seed=None):
sess = get_session()
if isinstance(metrics, str):
metrics = [metrics]
elif callable(metrics):
metrics = [metrics]
elif not isinstance(metrics, list):
raise TypeError("metrics must have type str or list.")
raise TypeError("metrics must have type str or list, or be callable.")

check_data(data)
if not isinstance(n_samples, int):
Expand Down Expand Up @@ -218,6 +220,8 @@ def evaluate(metrics, data, n_samples=500, output_key=None, seed=None):
log_pred = [sess.run(tensor, feed_dict) for _ in range(n_samples)]
log_pred = tf.add_n(log_pred) / tf.cast(n_samples, tensor.dtype)
evaluations += [log_pred]
elif callable(metric):
evaluations += [metric(y_true, y_pred, **params)]
else:
raise NotImplementedError("Metric is not implemented: {}".format(metric))

Expand Down
13 changes: 13 additions & 0 deletions tests/criticisms/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,5 +171,18 @@ def test_output_key(self):
{x: x_data, y: y_data, x_ph: x_ph_data}, n_samples=1,
output_key='x')

def test_custom_metric(self):
def logcosh(y_true, y_pred):
diff = y_pred - y_true
return tf.reduce_mean(diff + tf.nn.softplus(-2.0 * diff) - tf.log(2.0),
axis=-1)
with self.test_session():
x = Normal(loc=0.0, scale=1.0)
x_data = tf.constant(0.0)
ed.evaluate(logcosh, {x: x_data}, n_samples=1)
ed.evaluate(['mean_squared_error', logcosh], {x: x_data}, n_samples=1)
self.assertRaises(NotImplementedError, ed.evaluate, 'logcosh',
{x: x_data}, n_samples=1)

if __name__ == '__main__':
tf.test.main()

0 comments on commit a58e159

Please sign in to comment.