propagate nan in some activations (#8033)
* propagate nan in some activations

* fix py2 not having math.nan

* flake8
ssnl authored Jun 1, 2018
1 parent 8b447fa commit bf29abd
Showing 6 changed files with 58 additions and 18 deletions.
4 changes: 2 additions & 2 deletions aten/src/TH/generic/THTensorMath.cpp
@@ -2803,13 +2803,13 @@ void THTensor_(cmin)(THTensor *r, THTensor *t, THTensor *src) {
void THTensor_(cmaxValue)(THTensor *r, THTensor *t, real value) {
THTensor_(resizeAs)(r, t);
TH_TENSOR_APPLY2(real, r, real, t,
- *r_data = *t_data > value ? *t_data : value;);
+ *r_data = *t_data < value ? value : *t_data;); // this order propagates NaN
}

void THTensor_(cminValue)(THTensor *r, THTensor *t, real value) {
THTensor_(resizeAs)(r, t);
TH_TENSOR_APPLY2(real, r, real, t,
- *r_data = *t_data < value ? *t_data : value;);
+ *r_data = *t_data > value ? value : *t_data;); // this order propagates NaN
}

void THTensor_(zeros)(THTensor *r_, THLongStorage *size)
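The reordering works because IEEE 754 ordered comparisons involving NaN always evaluate to false, so whichever operand sits in the false branch of the ternary is the one returned for a NaN element. A minimal standalone C++ sketch of the cmaxValue case (not part of this diff; names like nan_val and old_form are purely illustrative) — the same reordering is applied to the CUDA functors in THCTensorMathPointwise.cuh below:

#include <cassert>
#include <cmath>
#include <limits>

// Illustrative only: any ordered comparison with NaN is false, so the operand
// in the false branch of the ternary is what survives for a NaN input.
int main() {
  double nan_val = std::numeric_limits<double>::quiet_NaN();
  double value = 0.5;

  double old_form = nan_val > value ? nan_val : value;  // comparison is false -> returns value, NaN is lost
  double new_form = nan_val < value ? value : nan_val;  // comparison is false -> returns nan_val, NaN propagates

  assert(old_form == value && !std::isnan(old_form));
  assert(std::isnan(new_form));
  return 0;
}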
8 changes: 4 additions & 4 deletions aten/src/THC/THCTensorMathPointwise.cuh
@@ -655,11 +655,11 @@ struct TensorMaxValueOp {
TensorMaxValueOp(T v) : val(v) {}

__device__ __forceinline__ void operator()(T* out) {
- *out = THCNumerics<T>::gt(*out, val) ? *out : val;
+ *out = THCNumerics<T>::lt(*out, val) ? val : *out; // this order propagates NaN
}

__device__ __forceinline__ void operator()(T* out, T* in) {
- *out = THCNumerics<T>::gt(*in, val) ? *in : val;
+ *out = THCNumerics<T>::lt(*in, val) ? val : *in; // this order propagates NaN
}

T val;
@@ -670,11 +670,11 @@ struct TensorMinValueOp {
TensorMinValueOp(T v) : val(v) {}

__device__ __forceinline__ void operator()(T* out) {
- *out = THCNumerics<T>::lt(*out, val) ? *out : val;
+ *out = THCNumerics<T>::gt(*out, val) ? val : *out; // this order propagates NaN
}

__device__ __forceinline__ void operator()(T* out, T* in) {
- *out = THCNumerics<T>::lt(*in, val) ? *in : val;
+ *out = THCNumerics<T>::gt(*in, val) ? val : *in; // this order propagates NaN
}

T val;
6 changes: 3 additions & 3 deletions aten/src/THCUNN/HardTanh.cu
@@ -18,10 +18,10 @@ struct hardtanhupdateOutput_functor
{
if (*input < min_val_)
*output = min_val_;
- else if (*input <= max_val_)
-   *output = *input;
- else
+ else if (*input > max_val_)
    *output = max_val_;
+ else
+   *output = *input;
}

__device__ void operator()(T *input) const
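To make the branch reordering concrete, here is a standalone C++ sketch of the hardtanh clamp before and after this change (illustrative only; the real kernels are the THCUNN/THNN versions in this diff, and aten/src/THNN/generic/HardTanh.c below receives the same reordering; function names here are invented for the example):

#include <cassert>
#include <cmath>
#include <limits>

// Old ordering: a NaN input fails both tests and is silently clamped to max_val.
double hardtanh_old(double x, double min_val, double max_val) {
  if (x < min_val) return min_val;
  else if (x <= max_val) return x;
  else return max_val;
}

// New ordering: a NaN input fails both tests and falls through unchanged.
double hardtanh_new(double x, double min_val, double max_val) {
  if (x < min_val) return min_val;
  else if (x > max_val) return max_val;
  else return x;
}

int main() {
  double nan_val = std::numeric_limits<double>::quiet_NaN();
  assert(hardtanh_old(nan_val, -1.0, 1.0) == 1.0);        // old behaviour: NaN becomes max_val
  assert(std::isnan(hardtanh_new(nan_val, -1.0, 1.0)));   // new behaviour: NaN propagates
  return 0;
}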
12 changes: 6 additions & 6 deletions aten/src/THNN/generic/HardShrink.c
@@ -14,10 +14,10 @@ void THNN_(HardShrink_updateOutput)(
TH_TENSOR_APPLY2(real, output, real, input,
if (*input_data > lambda)
*output_data = *input_data;
- else if (*input_data < -lambda)
-   *output_data = *input_data;
- else
+ else if (*input_data >= -lambda)
    *output_data = 0;
+ else
+   *output_data = *input_data; // let NaN case pass through here
);
}

@@ -32,10 +32,10 @@ void THNN_(HardShrink_updateGradInput)(
THNN_CHECK_NELEMENT(input, gradOutput);
THTensor_(resizeAs)(gradInput, input);
TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, input,
- if (*input_data > lambda || *input_data < -lambda)
-   *gradInput_data = *gradOutput_data;
- else
+ if (*input_data >= -lambda && *input_data <= lambda)
    *gradInput_data = 0;
+ else
+   *gradInput_data = *gradOutput_data; // let NaN case pass through here
);
}

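The same pattern applies to hardshrink: with the reordered branches, a NaN input fails both comparisons and lands in the final else, so the forward pass returns it unchanged and the backward pass passes gradOutput through. A standalone C++ sketch mirroring the change above (illustrative only; function names are invented for the example):

#include <cassert>
#include <cmath>
#include <limits>

// Forward: the final else is reached for x < -lambda and for NaN inputs.
double hardshrink_forward(double x, double lambda) {
  if (x > lambda) return x;
  else if (x >= -lambda) return 0.0;
  else return x;
}

// Backward: the else is reached outside the [-lambda, lambda] band and for NaN inputs,
// so gradOutput passes through in those cases.
double hardshrink_grad(double x, double grad_out, double lambda) {
  if (x >= -lambda && x <= lambda) return 0.0;
  else return grad_out;
}

int main() {
  double nan_val = std::numeric_limits<double>::quiet_NaN();
  assert(std::isnan(hardshrink_forward(nan_val, 0.5)));   // forward now propagates NaN
  assert(hardshrink_grad(nan_val, 2.0, 0.5) == 2.0);      // gradient passes gradOutput through for NaN input
  return 0;
}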
6 changes: 3 additions & 3 deletions aten/src/THNN/generic/HardTanh.c
@@ -33,10 +33,10 @@ void THNN_(HardTanh_updateOutput)(
TH_TENSOR_APPLY2(real, output, real, input,
if (*input_data < min_val)
*output_data = min_val;
- else if (*input_data <= max_val)
-   *output_data = *input_data;
- else
+ else if (*input_data > max_val)
    *output_data = max_val;
+ else
+   *output_data = *input_data;
);
}
}
40 changes: 40 additions & 0 deletions test/test_nn.py
@@ -1338,6 +1338,46 @@ def test_vector_to_parameters(self):
         sample = next(model.parameters())[0, 0, 0]
         self.assertTrue(torch.equal(sample.data, vec.data[:5]))

+    # We don't want to make propagating NaN a hard requirement on ops, but for
+    # these easy ones, we should make them do so.
+    def _test_nonlinearity_propagate_nan(self, device):
+        nan = float('nan')
+
+        def test(nonlinearity, *args, **kwargs):
+            x = torch.tensor([nan], device=device)
+            fn = getattr(F, nonlinearity)
+            try:
+                self.assertTrue(math.isnan(fn(x, *args, **kwargs).item()))
+            except Exception as e:
+                if 'not implemented' not in str(e):
+                    raise
+
+        test('relu')
+        test('relu', inplace=True)
+        test('relu6')
+        test('elu')
+        test('selu')
+        test('rrelu')
+        test('rrelu', inplace=True)
+        test('hardtanh')
+        test('tanh')
+        test('sigmoid')
+        test('logsigmoid')
+        test('hardshrink')
+        test('tanhshrink')
+        test('softsign')
+        test('softmin', 0)
+        test('softmax', 0)
+        test('log_softmax', 0)
+        test('leaky_relu', 0.2)
+
+    def test_nonlinearity_propagate_nan(self):
+        self._test_nonlinearity_propagate_nan('cpu')
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_nonlinearity_propagate_nan_cuda(self):
+        self._test_nonlinearity_propagate_nan('cuda')
+
     def test_weight_norm(self):
         input = torch.randn(3, 5)
         m = nn.Linear(5, 7)
