Skip to content

Commit

Permalink
FIX: compile triplet loss within keras model (tensorflow#298)
Browse files Browse the repository at this point in the history
* FIX: compile triplet loss within keras model

* remove dummy assertion

* add a testcase when the shape of y_true is invalid

* remove testcase of invalid shape
  • Loading branch information
WindQAQ authored Jun 17, 2019
1 parent 25cf38e commit 229344f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
3 changes: 1 addition & 2 deletions tensorflow_addons/losses/triplet.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,8 @@ def triplet_semihard_loss(y_true, y_pred, margin=1.0):
margin: Float, margin term in the loss definition.
"""
labels, embeddings = y_true, y_pred
# Reshape [batch_size] label tensor to a [batch_size, 1] label tensor.
# Reshape label tensor to [batch_size, 1].
lshape = tf.shape(labels)
assert lshape.shape == 1
labels = tf.reshape(labels, [lshape[0], 1])

# Build pairwise squared distance matrix.
Expand Down
7 changes: 7 additions & 0 deletions tensorflow_addons/losses/triplet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@ def test_unweighted(self):
loss = cce_obj(y_true, y_pred)
self.assertAlmostEqual(self.evaluate(loss), loss_np, 3)

def test_keras_model_compile(self):
model = tf.keras.models.Sequential([
tf.keras.layers.Input(shape=(784,)),
tf.keras.layers.Dense(10),
])
model.compile(loss="triplet_semihard_loss", optimizer="adam")


if __name__ == '__main__':
tf.test.main()

0 comments on commit 229344f

Please sign in to comment.