Merge pull request deepchem#1118 from abster12/master
Adding Hinge Loss Layer
lilleswing authored Mar 8, 2018
2 parents 38e109d + db699dd commit d5282b3
Showing 3 changed files with 82 additions and 1 deletion.
33 changes: 33 additions & 0 deletions deepchem/models/tensorgraph/layers.py
@@ -1542,6 +1542,12 @@ def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):


class SparseSoftMaxCrossEntropy(Layer):
  """Computes sparse softmax cross entropy between logits and labels.

  labels: Tensor of shape [d_0, d_1, ..., d_{r-1}] (where r is the rank of the logits), with dtype int32 or int64.
  logits: Unscaled log probabilities of shape [d_0, ..., d_{r-1}, num_classes], with dtype float32 or float64.
  Note: the rank of the logits must be exactly one greater than the rank of the labels.
  The output is a tensor with the same shape as the labels and the same dtype as the logits, holding the per-element loss.
  """

  def __init__(self, in_layers=None, **kwargs):
    super(SparseSoftMaxCrossEntropy, self).__init__(in_layers, **kwargs)
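
The shape contract in this docstring matches TensorFlow's tf.nn.sparse_softmax_cross_entropy_with_logits, which the layer presumably delegates to (its create_tensor method is not shown in this hunk). A minimal standalone sketch of that contract, with illustrative numbers that are not part of this diff:

import numpy as np
import tensorflow as tf

batch_size, num_classes = 4, 3
logits = tf.constant(np.random.rand(batch_size, num_classes), dtype=tf.float32)
labels = tf.constant([0, 2, 1, 2], dtype=tf.int32)  # one class index per sample

# Per-sample cross entropy: rank of logits is one greater than rank of labels.
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)

with tf.Session() as sess:
  per_sample_loss = sess.run(loss)
  assert per_sample_loss.shape == (batch_size,)  # same shape as labels
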
@@ -4309,3 +4315,30 @@ def batch_mat_mult(self, A, B):
result = tf.matmul(A_reshape, B)
result = tf.reshape(result, tf.stack([A_shape[0], A_shape[1], axis_2]))
return result


class Hingeloss(Layer):
  """Computes the hinge loss on inputs [labels, logits].

  labels: Tensor whose values are expected to be 0.0 or 1.0, with the same shape as logits.
  logits: Float tensor holding the unscaled scores (log probabilities) for the labels.
  The output is an element-wise loss tensor with the same shape as labels.
  """

  def __init__(self, in_layers=None, **kwargs):
    super(Hingeloss, self).__init__(in_layers, **kwargs)
    try:
      self._shape = self.in_layers[1].shape
    except:
      pass

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    inputs = self._get_input_tensors(in_layers)
    if len(inputs) != 2:
      raise ValueError("Hingeloss layer expects exactly two inputs: [labels, logits].")
    labels, logits = inputs[0], inputs[1]
    reduction = tf.losses.Reduction
    out_tensor = tf.losses.hinge_loss(
        labels=labels, logits=logits, reduction=reduction.NONE)
    if set_tensors:
      self.out_tensor = out_tensor
    return out_tensor
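
Because create_tensor passes reduction=Reduction.NONE, tf.losses.hinge_loss returns the element-wise loss rather than a scalar. A standalone sketch of what that call computes, with illustrative numbers that are not part of this diff; internally the {0, 1} labels are mapped to {-1, +1} and the per-element loss is max(0, 1 - label * logit):

import numpy as np
import tensorflow as tf

labels = tf.constant([1.0, 0.0, 1.0])
logits = tf.constant([0.8, -0.5, -0.2])

# Element-wise hinge loss: [1 - 0.8, 1 - 0.5, 1 + 0.2] = [0.2, 0.5, 1.2]
loss = tf.losses.hinge_loss(
    labels=labels, logits=logits, reduction=tf.losses.Reduction.NONE)

with tf.Session() as sess:
  np.testing.assert_allclose(sess.run(loss), [0.2, 0.5, 1.2], rtol=1e-5)
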
27 changes: 27 additions & 0 deletions deepchem/models/tensorgraph/tests/test_layers.py
@@ -25,6 +25,7 @@
from deepchem.models.tensorgraph.layers import Gather
from deepchem.models.tensorgraph.layers import GraphConv
from deepchem.models.tensorgraph.layers import GraphGather
from deepchem.models.tensorgraph.layers import Hingeloss
from deepchem.models.tensorgraph.layers import Input
from deepchem.models.tensorgraph.layers import InputFifoQueue
from deepchem.models.tensorgraph.layers import InteratomicL2Distances
@@ -45,6 +46,7 @@
from deepchem.models.tensorgraph.layers import SigmoidCrossEntropy
from deepchem.models.tensorgraph.layers import SoftMax
from deepchem.models.tensorgraph.layers import SoftMaxCrossEntropy
from deepchem.models.tensorgraph.layers import SparseSoftMaxCrossEntropy
from deepchem.models.tensorgraph.layers import StopGradient
from deepchem.models.tensorgraph.layers import TensorWrapper
from deepchem.models.tensorgraph.layers import TimeSeriesDense
@@ -381,6 +383,18 @@ def test_softmax_cross_entropy(self):
out_tensor = out_tensor.eval()
assert out_tensor.shape == (batch_size,)

def test_sparse_softmax_cross_entropy(self):
  batch_size = 10
  n_features = 5
  logit_tensor = np.random.rand(batch_size, n_features)
  # Labels must be valid class indices in [0, n_features).
  label_tensor = np.random.randint(0, n_features, batch_size)
  with self.test_session() as sess:
    logit_tensor = tf.convert_to_tensor(logit_tensor, dtype=tf.float32)
    label_tensor = tf.convert_to_tensor(label_tensor, dtype=tf.int32)
    out_tensor = SparseSoftMaxCrossEntropy()(label_tensor, logit_tensor)
    out_tensor = out_tensor.eval()
    assert out_tensor.shape == (batch_size,)

def test_reduce_mean(self):
"""Test that ReduceMean can be invoked."""
batch_size = 10
@@ -875,3 +889,16 @@ def test_IRV(self):
assert out_tensor.shape == (batch_size, n_tasks)
irv_reg = IRVRegularize(irv_layer, 1.)()
assert irv_reg.eval() >= 0

def test_hingeloss(self):
  n_samples = 1
  logits_tensor = np.random.rand(n_samples)
  # Hinge loss expects label values of 0.0 or 1.0.
  labels_tensor = np.random.randint(0, 2, n_samples).astype(np.float64)
  with self.test_session() as sess:
    logits_tensor = tf.convert_to_tensor(logits_tensor, dtype=tf.float32)
    labels_tensor = tf.convert_to_tensor(labels_tensor, dtype=tf.float32)
    out_tensor = Hingeloss()(labels_tensor, logits_tensor)
    out_tensor = out_tensor.eval()
    assert out_tensor.shape == (n_samples,)
23 changes: 22 additions & 1 deletion deepchem/models/tensorgraph/tests/test_layers_pickle.py
@@ -10,7 +10,7 @@
SoftMaxCrossEntropy, ReduceMean, ToFloat, ReduceSquareDifference, Conv2D, MaxPool2D, ReduceSum, GraphConv, GraphPool, \
GraphGather, BatchNorm, WeightedError, ReLU, \
Conv3D, MaxPool3D, Conv2DTranspose, Conv3DTranspose, \
LSTMStep, AttnLSTMEmbedding, IterRefLSTMEmbedding, GraphEmbedPoolLayer, GraphCNN, Cast
LSTMStep, AttnLSTMEmbedding, IterRefLSTMEmbedding, GraphEmbedPoolLayer, GraphCNN, Cast, Hingeloss, SparseSoftMaxCrossEntropy
from deepchem.models.tensorgraph.symmetry_functions import AtomicDifferentiatedDense
from deepchem.models.tensorgraph.IRV import IRVLayer, IRVRegularize, Slice

@@ -269,6 +269,17 @@ def test_SoftmaxCrossEntropy_pickle():
tg.save()


def test_SparseSoftmaxCrossEntropy_pickle():
tg = TensorGraph()
logits = Feature(shape=(tg.batch_size, 5))
labels = Feature(shape=(tg.batch_size,), dtype=tf.int32)
layer = SparseSoftMaxCrossEntropy(in_layers=[labels, logits])
tg.add_output(layer)
tg.set_loss(layer)
tg.build()
tg.save()


def test_SigmoidCrossEntropy_pickle():
tg = TensorGraph()
feature = Feature(shape=(tg.batch_size, 1))
@@ -682,3 +693,13 @@ def test_Slice_pickle():
tg.set_loss(out)
tg.build()
tg.save()


def test_hingeloss_pickle():
tg = TensorGraph()
feature = Feature(shape=(1, None))
layer = Hingeloss(in_layers=[feature, feature])
tg.add_output(layer)
tg.set_loss(layer)
tg.build()
tg.save()
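
A hypothetical usage sketch (not part of this diff) of how Hingeloss might be wired as the training loss of a small TensorGraph model, with separate Label and Feature inputs and a ReduceMean to collapse the element-wise loss; the layer names and shapes here are illustrative only:

from deepchem.models.tensorgraph.tensor_graph import TensorGraph
from deepchem.models.tensorgraph.layers import Dense, Feature, Hingeloss, Label, ReduceMean

tg = TensorGraph()
features = Feature(shape=(None, 10))
labels = Label(shape=(None, 1))  # values expected to be 0.0 or 1.0
logits = Dense(out_channels=1, in_layers=[features])
per_sample_loss = Hingeloss(in_layers=[labels, logits])
tg.add_output(logits)
tg.set_loss(ReduceMean(in_layers=[per_sample_loss]))
tg.build()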
