Skip to content

Commit

Permalink
Fix overflow of Sigmoid in CPU Backend
Browse files Browse the repository at this point in the history
  • Loading branch information
Yinghai Lu authored and rdzhabarov committed Nov 7, 2018
1 parent b727ada commit cf2326c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
8 changes: 4 additions & 4 deletions lib/Backends/CPU/libjit/libjit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -746,8 +746,8 @@ int8_t libjit_elementselect_kernel_i8(size_t idx, const int8_t *cond,
}

DEFINE_DATA_PARALLEL_KERNEL_FUNC(libjit_sigmoid_kernel_f) {
float e = expf(LHS[idx]);
return e / (e + 1);
float e = expf(-LHS[idx]);
return 1 / (e + 1);
}
DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(libjit_element_maxsplat_kernel_f,
float, MAX(LHS[idx], val))
Expand Down Expand Up @@ -1287,8 +1287,8 @@ void libjit_softmax_grad_f(float *inG, float *outW, const size_t *selectedW,

void libjit_sigmoid_f(const float *inW, float *outW, size_t numElem) {
for (size_t i = 0; i < numElem; i++) {
float e = expf(inW[i]);
outW[i] = e / (e + 1);
float e = expf(-inW[i]);
outW[i] = 1 / (e + 1);
}
}

Expand Down
18 changes: 18 additions & 0 deletions tests/unittests/OperatorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3381,6 +3381,24 @@ TEST_P(InterpAndCPU, NonSquareStrideMaxPool) {
EXPECT_EQ(result.getHandle().raw(i), ref[i]);
}

TEST_P(InterpAndCPU, SigmoidOverflow) {
auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {2}, "input", false);
auto IH = ctx_.allocate(input)->getHandle();
IH.raw(0) = 1000;
IH.raw(1) = -1000;

auto *fpSigmoid = F_->createSigmoid("fpSigmoid", input);
auto *S = F_->createSave("fpSave", fpSigmoid);
ctx_.allocate(S->getPlaceholder());
EE_.compile(CompilationMode::Infer, F_, ctx_);
EE_.run(ctx_);
Tensor &result = *ctx_.get(S->getPlaceholder());
static const float ref[] = {1, 0};
for (size_t i = 0; i < 2; i++) {
EXPECT_EQ(result.getHandle().raw(i), ref[i]);
}
}

TEST_P(InterpAndCPU, Int8Sigmoid) {
constexpr size_t size = 10;
auto *input =
Expand Down

0 comments on commit cf2326c

Please sign in to comment.