Skip to content

Commit

Permalink
inplace hardtanh, remove relu6
Browse files Browse the repository at this point in the history
  • Loading branch information
szagoruyko committed Jun 16, 2016
1 parent 6161049 commit d542253
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 113 deletions.
58 changes: 51 additions & 7 deletions HardTanh.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,36 @@ struct hardtanhupdateOutput_functor
else
*output = max_val_;
}

// Single-tensor (in-place forward) variant: clamps *input into
// [min_val_, max_val_] directly, writing the bound value back when the
// element falls outside the range.
__device__ void operator()(float *input) const
{
if (*input < min_val_)
*input = min_val_;
else if (*input > max_val_)
*input = max_val_;
}
};

// Forward pass of HardTanh: output = clamp(input, min_val, max_val).
//
// When `inplace` is true, `output` is made to share `input`'s storage via
// THCudaTensor_set and the clamp is applied with a one-tensor kernel,
// avoiding the allocation of a second tensor. Otherwise `output` is
// resized to match `input` and filled by the two-tensor kernel.
void THNN_CudaHardTanh_updateOutput(
  THCState *state,
  THCudaTensor *input,
  THCudaTensor *output,
  float min_val,
  float max_val,
  bool inplace)
{
  THCUNN_assertSameGPU(state, 2, input, output);
  if (inplace)
  {
    // output aliases input; clamp input's elements directly.
    THCudaTensor_set(state, output, input);
    THC_pointwiseApply1(state, output,
      hardtanhupdateOutput_functor(min_val, max_val));
  }
  else
  {
    THCudaTensor_resizeAs(state, output, input);
    THC_pointwiseApply2(state, output, input,
      hardtanhupdateOutput_functor(min_val, max_val));
  }
}

struct hardtanhupdateGradInput_functor
Expand All @@ -47,13 +69,35 @@ struct hardtanhupdateGradInput_functor
else
*gradInput = *gradOutput;
}

// Two-tensor (in-place backward) variant: zeroes *gradInput wherever the
// forward input is at or beyond a clamp bound, and leaves it untouched
// elsewhere. NOTE(review): this only copies nothing — it assumes
// gradInput already holds the gradOutput values (the caller aliases them
// via THCudaTensor_set); verify against the inplace updateGradInput path.
__device__ void operator()(float *gradInput, const float *input) const
{
if (*input <= min_val_ || *input >= max_val_)
*gradInput = 0;
}
};

// Backward pass of HardTanh: gradInput = gradOutput where input lies
// strictly inside (min_val, max_val), and 0 where it saturated.
//
// Branches fixed to match the functor semantics: the two-argument
// functor only zeroes saturated positions and never copies gradOutput,
// so it is only valid when gradInput already aliases gradOutput — that
// is the `inplace` path. The non-inplace path must allocate (resize)
// gradInput and use the three-tensor kernel, which both copies and
// masks; otherwise gradOutput's storage would be clobbered.
void THNN_CudaHardTanh_updateGradInput(
  THCState *state,
  THCudaTensor *input,
  THCudaTensor *gradOutput,
  THCudaTensor *gradInput,
  float min_val,
  float max_val,
  bool inplace)
{
  THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);

  if (inplace)
  {
    // gradInput aliases gradOutput; zero the saturated entries in place.
    THCudaTensor_set(state, gradInput, gradOutput);
    THC_pointwiseApply2(state, gradInput, input,
      hardtanhupdateGradInput_functor(min_val, max_val));
  }
  else
  {
    THCudaTensor_resizeAs(state, gradInput, input);
    THC_pointwiseApply3(state, gradInput, input, gradOutput,
      hardtanhupdateGradInput_functor(min_val, max_val));
  }
}
92 changes: 0 additions & 92 deletions ReLU6.cu

This file was deleted.

18 changes: 4 additions & 14 deletions THCUNN.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,16 @@ TH_API void THNN_CudaHardTanh_updateOutput(
THCudaTensor *input,
THCudaTensor *output,
float min_val,
float max_val);
float max_val,
bool inplace);
TH_API void THNN_CudaHardTanh_updateGradInput(
THCState *state,
THCudaTensor *input,
THCudaTensor *gradOutput,
THCudaTensor *gradInput,
float min_val,
float max_val);
float max_val,
bool inplace);

TH_API void THNN_CudaL1Cost_updateOutput(
THCState *state,
Expand Down Expand Up @@ -403,18 +405,6 @@ TH_API void THNN_CudaThreshold_updateGradInput(
double threshold,
bool inplace);

TH_API void THNN_CudaReLU6_updateOutput(
THCState *state,
THCudaTensor *input,
THCudaTensor *output,
bool inplace);
TH_API void THNN_CudaReLU6_updateGradInput(
THCState *state,
THCudaTensor *input,
THCudaTensor *gradOutput,
THCudaTensor *gradInput,
bool inplace);

TH_API void THNN_CudaTemporalConvolution_updateOutput(
THCState *state,
THCudaTensor *input,
Expand Down

0 comments on commit d542253

Please sign in to comment.