From db369517d754df47940f338acd1af8b81e2fad0c Mon Sep 17 00:00:00 2001 From: Ina Dobreva <55383260+inadob@users.noreply.github.com> Date: Sun, 1 Dec 2019 00:16:44 +0000 Subject: [PATCH] [Relay][Frontend][TFlite] Add test for qnn_mul operator (#4395) * Add a function to set the qnn output range wrt each elemwise operation. * Add comments warning for nonsense clamped output in the tflite/tvm results comparison. --- tests/python/frontend/tflite/test_forward.py | 38 +++++++++++++++----- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 555dc579b0d8..ad7989f2da4f 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -143,7 +143,8 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors, converter.inference_type = tf.lite.constants.QUANTIZED_UINT8 input_arrays = converter.get_input_arrays() input_stats = {} - # hardcode the mean_values and std_dev_values (m,s) to be the same for all inputs + # hardcode the mean_values and std_dev_values (m,s) to be the same + # if all inputs are in (float_min; float_max) == (-100, 100) # s = 255/(fmax-fmin); m = -fmin*s (the zero point) for i in input_arrays: input_stats[i] = (128., 1.275) @@ -160,6 +161,10 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors, tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device, num_output=len(out_names), out_names=out_names) + + # WARNING: the results could well be random values clipped to 0 or 255 because of badly tuned output + # range for the specific operator. While adding test ensure that we aren't getting only clipped values + # in output tensors that still pass the assertion. For reference see _test_elemwise_qnn_out_range() if quantized: for i in range(len(tflite_output)): # allow absolute tolerance of 1 in the quantized results @@ -562,7 +567,7 @@ def test_forward_concatenation(): # Element-wise # --- -def _test_elemwise(math_op, data, fused_activation_function=None, quantized=False): +def _test_elemwise(math_op, data, fused_activation_function=None, quantized=False, qnn_op=None): """ One iteration of elemwise """ assert len(data) == 2 @@ -578,7 +583,9 @@ def _test_elemwise(math_op, data, fused_activation_function=None, quantized=Fals tf.quantization.fake_quant_with_min_max_args(in_data[1], min=-100, max=100, name="inq_1")] out = math_op(inq_data[0], inq_data[1]) out = with_fused_activation_function(out, fused_activation_function) - out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out") + # set the quantized output range with respect to the operation + out_min, out_max = _test_elemwise_qnn_out_range(qnn_op) + out = tf.quantization.fake_quant_with_min_max_args(out, min=out_min, max=out_max, name="out") compare_tflite_with_tvm(data, ['inq_0:0', 'inq_1:0'], inq_data, [out], quantized=True) else: out = math_op(in_data[0], in_data[1]) @@ -595,7 +602,8 @@ def _test_elemwise(math_op, data, fused_activation_function=None, quantized=Fals # the 2nd tensor is treated as constant and directly added as part of the operation out = math_op(inq_data, ops.convert_to_tensor(inq_const, dtype='float32', name='inq_const')) out = with_fused_activation_function(out, fused_activation_function) - out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out") + out_min, out_max = _test_elemwise_qnn_out_range(qnn_op) + out = tf.quantization.fake_quant_with_min_max_args(out, min=out_min, max=out_max, name="out") compare_tflite_with_tvm(data[0], ['inq_0:0'], inq_data, [out], quantized=True) else: out = math_op(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype)) @@ -606,9 +614,9 @@ def _test_elemwise(math_op, data, fused_activation_function=None, quantized=Fals # Add # --- -def _test_add(data, fused_activation_function=None, quantized=False): +def _test_add(data, fused_activation_function=None, quantized=False, qnn_op=None): """ One iteration of add """ - return _test_elemwise(math_ops.add, data, fused_activation_function, quantized) + return _test_elemwise(math_ops.add, data, fused_activation_function, quantized, qnn_op) ####################################################################### # Subtract @@ -620,9 +628,10 @@ def _test_sub(data, fused_activation_function=None): ####################################################################### # Mul # --- -def _test_mul(data, fused_activation_function=None): + +def _test_mul(data, fused_activation_function=None, quantized=False, qnn_op=None): """ One iteration of mul """ - return _test_elemwise(math_ops.multiply, data, fused_activation_function) + return _test_elemwise(math_ops.multiply, data, fused_activation_function, quantized, qnn_op) ####################################################################### # Divide @@ -671,7 +680,17 @@ def _test_forward_elemwise(testop): def _test_forward_elemwise_quantized(testop): testop([np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8), - np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8)], quantized=True) + np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8)], quantized=True, qnn_op=testop) + +def _test_elemwise_qnn_out_range(qnn_op): + # set the fake_quant output range if input tensors are in [-100, 100] float32 + qnn_out_range = { + _test_add: (-200, 200), + _test_sub: (-200, 200), + _test_mul: (-1e+4, 1e+4), + } + + return qnn_out_range[qnn_op] def test_all_elemwise(): _test_forward_elemwise(_test_add) @@ -682,6 +701,7 @@ def test_all_elemwise(): _test_forward_elemwise(partial(_test_sub, fused_activation_function="RELU")) _test_forward_elemwise(partial(_test_sub, fused_activation_function="RELU6")) _test_forward_elemwise(_test_mul) + _test_forward_elemwise_quantized(_test_mul) _test_forward_elemwise(partial(_test_mul, fused_activation_function="RELU")) _test_forward_elemwise(partial(_test_mul, fused_activation_function="RELU6")) _test_forward_elemwise(_test_div)