Skip to content

Commit

Permalink
Fix problem in quantized version of Comparison op handler
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 215801773
  • Loading branch information
tensorflower-gardener committed Oct 4, 2018
1 parent 4c1da53 commit a2e48d8
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
16 changes: 5 additions & 11 deletions tensorflow/contrib/lite/kernels/comparisons.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,31 +66,25 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
if (input1->type == kTfLiteUInt8) { \
auto input1_offset = -input1->params.zero_point; \
auto input2_offset = -input2->params.zero_point; \
const int left_shift = 20; \
const double twice_max_input_scale = \
2 * std::max(input1->params.scale, input2->params.scale); \
const double real_input1_multiplier = \
input1->params.scale / twice_max_input_scale; \
const double real_input2_multiplier = \
input2->params.scale / twice_max_input_scale; \
const int left_shift = 8; \
\
int32 input1_multiplier; \
int input1_shift; \
QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, \
QuantizeMultiplierSmallerThanOneExp(input1->params.scale, \
&input1_multiplier, &input1_shift); \
int32 input2_multiplier; \
int input2_shift; \
QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, \
QuantizeMultiplierSmallerThanOneExp(input2->params.scale, \
&input2_multiplier, &input2_shift); \
\
ComparisonParams op_params; \
op_params.left_shift = left_shift; \
op_params.input1_offset = input1_offset; \
op_params.input1_multiplier = input1_multiplier; \
op_params.input1_shift = -input1_shift; \
op_params.input1_shift = input1_shift; \
op_params.input2_offset = input2_offset; \
op_params.input2_multiplier = input2_multiplier; \
op_params.input2_shift = -input2_shift; \
op_params.input2_shift = input2_shift; \
if (requires_broadcast) { \
reference_ops::Broadcast4DSlow##opname##WithScaling( \
op_params, GetTensorShape(input1), GetTensorData<uint8_t>(input1), \
Expand Down
11 changes: 11 additions & 0 deletions tensorflow/contrib/lite/kernels/comparisons_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,17 @@ TEST(ComparisonsTest, GreaterQuantized) {
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
}

TEST(ComparisonsTest, GreaterQuantizedSmallRange) {
ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, 0.0, 1.0},
{TensorType_UINT8, {1, 2, 2, 1}, 0.0, 2.0},
TensorType_UINT8, BuiltinOperator_GREATER);
model.QuantizeAndPopulate<uint8_t>(model.input1(), {1.0, 0.5, 0.35, 0.1});
model.QuantizeAndPopulate<uint8_t>(model.input2(), {1.01, 0.25, 0.3, 0.4});
model.Invoke();

EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
}

TEST(ComparisonsTest, GreaterEqualQuantized) {
const float kMin = -1.f;
const float kMax = 128.f;
Expand Down

0 comments on commit a2e48d8

Please sign in to comment.