Skip to content

Commit

Permalink
Added ReLU6 implementation and test.
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathantompson committed Jun 14, 2016
1 parent 5c2b253 commit 6161049
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 0 deletions.
92 changes: 92 additions & 0 deletions ReLU6.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#include "THCUNN.h"
#include "common.h"

struct ReLU6UpdateOutput
{
ReLU6UpdateOutput() {}

__device__ __forceinline__ void operator()(float *out, float *in)
{
float x = *in;
*out = (x > 0) ? ((x < 6) ? x : 6) : 0;
}
};

// in-place variant
struct ReLU6UpdateOutputIP
{
ReLU6UpdateOutputIP() {}

__device__ __forceinline__ void operator()(float *x)
{
*x = (*x > 0) ? ((*x < 6) ? *x : 6) : 0;
}
};

void THNN_CudaReLU6_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output,
bool inplace)
{
THCUNN_assertSameGPU(state, 2, input, output);

if (inplace)
{
THC_pointwiseApply1(state, input,
ReLU6UpdateOutputIP()
);
THCudaTensor_set(state, output, input);
}
else
{
THCudaTensor_resizeAs(state, output, input);
THC_pointwiseApply2(state, output, input,
ReLU6UpdateOutput()
);
}

THCudaCheck(cudaGetLastError());
}

struct ReLU6UpdateGradInput
{
ReLU6UpdateGradInput() {}

__device__ __forceinline__ void operator()(
float *gradInput, float *input, float *gradOutput) const
{
*gradInput = (*input > 0 && *input < 6) ? *gradOutput : 0;
}
};

struct ReLU6UpdateGradInputIP
{
ReLU6UpdateGradInputIP() {}

__device__ __forceinline__ void operator()(
float *gradOutput, float *input) const
{
*gradOutput = (*input > 0 && *input < 6) ? *gradOutput : 0;
}
};

void THNN_CudaReLU6_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput,
THCudaTensor *gradInput, bool inplace)
{
THCUNN_assertSameGPU(state, 3, input, gradInput, gradOutput);

if (inplace)
{
THC_pointwiseApply2(state, gradOutput, input,
ReLU6UpdateGradInputIP()
);
THCudaTensor_set(state, gradInput, gradOutput);
}
else
{
THCudaTensor_resizeAs(state, gradInput, input);
THC_pointwiseApply3(state, gradInput, input, gradOutput,
ReLU6UpdateGradInput()
);
}

THCudaCheck(cudaGetLastError());
}
12 changes: 12 additions & 0 deletions THCUNN.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,18 @@ TH_API void THNN_CudaThreshold_updateGradInput(
double threshold,
bool inplace);

TH_API void THNN_CudaReLU6_updateOutput(
THCState *state,
THCudaTensor *input,
THCudaTensor *output,
bool inplace);
TH_API void THNN_CudaReLU6_updateGradInput(
THCState *state,
THCudaTensor *input,
THCudaTensor *gradOutput,
THCudaTensor *gradInput,
bool inplace);

TH_API void THNN_CudaTemporalConvolution_updateOutput(
THCState *state,
THCudaTensor *input,
Expand Down

0 comments on commit 6161049

Please sign in to comment.