Skip to content

Commit

Permalink
Merge pull request deepchem#1786 from peastman/loss
Browse files Browse the repository at this point in the history
Be more tolerant of dtypes when computing loss
  • Loading branch information
Bharath Ramsundar authored Mar 26, 2020
2 parents f310f0a + 673e880 commit 8566ed9
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions deepchem/models/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class L1Loss(Loss):

def __call__(self, output, labels):
    """Return the elementwise absolute error ``|output - labels|``.

    Shapes are first reconciled via ``_make_shapes_consistent`` and both
    tensors are cast to a float dtype (non-float inputs become float32)
    so the subtraction is well defined for integer labels.
    """
    output, labels = _make_shapes_consistent(output, labels)
    output, labels = _ensure_float(output, labels)
    return tf.abs(output - labels)


Expand All @@ -39,6 +40,7 @@ class L2Loss(Loss):

def __call__(self, output, labels):
    """Return the elementwise squared error ``(output - labels)**2``.

    Shapes are first reconciled via ``_make_shapes_consistent`` and both
    tensors are cast to a float dtype (non-float inputs become float32)
    so the subtraction is well defined for integer labels.
    """
    output, labels = _make_shapes_consistent(output, labels)
    output, labels = _ensure_float(output, labels)
    return tf.square(output - labels)


Expand All @@ -63,6 +65,7 @@ class BinaryCrossEntropy(Loss):

def __call__(self, output, labels):
    """Return the binary cross entropy between labels and outputs.

    Delegates to ``tf.keras.losses.binary_crossentropy`` with its default
    arguments, so ``output`` is presumably expected to contain
    probabilities, not logits — confirm against callers. Non-float inputs
    are cast to float32 first.
    """
    output, labels = _make_shapes_consistent(output, labels)
    output, labels = _ensure_float(output, labels)
    return tf.keras.losses.binary_crossentropy(labels, output)


Expand All @@ -76,6 +79,7 @@ class CategoricalCrossEntropy(Loss):

def __call__(self, output, labels):
    """Return the categorical cross entropy between labels and outputs.

    Delegates to ``tf.keras.losses.categorical_crossentropy`` with its
    default arguments, so ``output`` is presumably expected to contain a
    probability distribution over classes, not logits — confirm against
    callers. Non-float inputs are cast to float32 first.
    """
    output, labels = _make_shapes_consistent(output, labels)
    output, labels = _ensure_float(output, labels)
    return tf.keras.losses.categorical_crossentropy(labels, output)


Expand All @@ -89,6 +93,7 @@ class SigmoidCrossEntropy(Loss):

def __call__(self, output, labels):
    """Return the sigmoid cross entropy, treating ``output`` as logits.

    Uses ``tf.nn.sigmoid_cross_entropy_with_logits``, which applies the
    sigmoid internally, so ``output`` must be raw (unsquashed) scores.
    Non-float inputs are cast to float32 first.
    """
    output, labels = _make_shapes_consistent(output, labels)
    output, labels = _ensure_float(output, labels)
    return tf.nn.sigmoid_cross_entropy_with_logits(labels, output)


Expand All @@ -103,6 +108,7 @@ class SoftmaxCrossEntropy(Loss):

def __call__(self, output, labels):
    """Return the softmax cross entropy, treating ``output`` as logits.

    Uses ``tf.nn.softmax_cross_entropy_with_logits``, which applies the
    softmax internally, so ``output`` must be raw (unnormalized) scores.
    Non-float inputs are cast to float32 first.
    """
    output, labels = _make_shapes_consistent(output, labels)
    output, labels = _ensure_float(output, labels)
    return tf.nn.softmax_cross_entropy_with_logits(labels, output)


Expand Down Expand Up @@ -142,3 +148,11 @@ def _make_shapes_consistent(output, labels):
return (output, labels)
raise ValueError("Incompatible shapes for outputs and labels: %s versus %s" %
(str(shape1), str(shape2)))

def _ensure_float(output, labels):
    """Make sure the outputs and labels are both floating point types.

    Tensors already of dtype float32 or float64 are returned unchanged;
    any other dtype (e.g. integer labels) is cast to float32.
    """

    def _as_float(tensor):
        # Preserve existing float precision; everything else becomes float32.
        if tensor.dtype in (tf.float32, tf.float64):
            return tensor
        return tf.cast(tensor, tf.float32)

    return (_as_float(output), _as_float(labels))

0 comments on commit 8566ed9

Please sign in to comment.