Skip to content

Commit

Permalink
[circle-quantizer] Propagate qparam of Pack (Samsung#7349)
Browse files Browse the repository at this point in the history
* [circle-quantizer] Propagate qparam of Pack

This propagates qparam of Pack.

ONE-DCO-1.0-Signed-off-by: Hyukjin Jeong <[email protected]>
  • Loading branch information
jinevening authored Aug 2, 2021
1 parent f633cee commit d36d6f3
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 0 deletions.
75 changes: 75 additions & 0 deletions compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,76 @@ struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<bool>
bool visit(luci::CircleNode *) { return false; }
};

/** EXAMPLE
*
* BEFORE
*
* [CircleNode] [CircleConst]
* (qparam1) (FP32)
* \ /
* \ /
* [CirclePack]
* (qparam2)
*
* AFTER
*
* [CircleNode] [CircleConst] [CircleConst] <- Dead node
* (qparam2) (qparam2) (FP32)
* \ /
* \ /
* [CirclePack]
* (qparam2)
*
* NOTE Quantization parameter of CirclePack (qparam2) is propagated to the inputs.
*/
void propagate_pack_quantparam(luci::CirclePack *pack, loco::DataType quant_type)
{
assert(pack->quantparam() != nullptr);

const auto num_inputs = pack->values_count();

for (uint32_t i = 0; i < num_inputs; i++)
{
auto node = loco::must_cast<luci::CircleNode *>(pack->arg(i));

// Skip if this input is PACK Op
if (node->opcode() == luci::CircleOpcode::PACK)
continue;

// Quantize constant values
if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
{
luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
if (const_node->dtype() != loco::DataType::FLOAT32)
throw std::runtime_error("Unsupported data type for constant input of pack Op");

const auto pack_qparam = pack->quantparam();
if (pack_qparam == nullptr)
throw std::runtime_error("quantparam of pack is not found during propagation");

assert(pack_qparam->scale.size() == 1);
assert(pack_qparam->zerop.size() == 1);
const auto scaling_factor = pack_qparam->scale[0];
const auto zerop = pack_qparam->zerop[0];

auto new_const = luci::clone(const_node);
quant_const_values(new_const, scaling_factor, zerop, quant_type);
pack->values(i, new_const);
overwrite_quantparam(pack, new_const);
}
else
{
const auto succs = loco::succs(node);
if (succs.size() > 1)
continue;

// Non-const input must have been quantized
assert(node->quantparam() != nullptr);
overwrite_quantparam(pack, node);
}
}
}

/**
* @brief Quantize const input tensors using min/max of const values
*/
Expand Down Expand Up @@ -1135,6 +1205,11 @@ void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type)
propagate_pad_v2_quantparam(loco::must_cast<CirclePadV2 *>(node), output_type);
break;

case luci::CircleOpcode::PACK:
// Quant param is propagated from output to inputs
propagate_pack_quantparam(loco::must_cast<CirclePack *>(node), output_type);
break;

default:
for (uint32_t i = 0; i < arity; i++)
{
Expand Down
68 changes: 68 additions & 0 deletions compiler/luci/pass/src/QuantizedModelVerifier.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,39 @@ class DepthToSpaceTestGraph final : public SimpleTestGraph
luci::CircleDepthToSpace *_dtos = nullptr;
};

class PackTestGraph final : public SimpleTestGraph
{
public:
void init(void) override
{
TestIOGraph::init({16}, {32});
_param = create_dummy_const<Type::FLOAT32>(g(), {16});
_pack = g()->nodes()->create<luci::CirclePack>(2);
{
_pack->values(0, input());
_pack->values(1, _param);
_pack->axis(0);
}
output()->from(_pack);

set_minmax_to_non_const(g(), -1, 1);

// Set min/max of the input
// pack's qparam will be propagted, overwritten to the input
auto input = loco::must_cast<luci::CircleNode *>(pack()->values(0));
auto qp = input->quantparam();
qp->min[0] = -0.5;
qp->max[0] = 0.5;
}

public:
luci::CirclePack *pack(void) { return _pack; }

private:
luci::CirclePack *_pack = nullptr;
luci::CircleConst *_param = nullptr;
};

class PadTestGraph final : public SimpleTestGraph
{
public:
Expand Down Expand Up @@ -1505,6 +1538,41 @@ TEST(QuantizedModelVerifierTest, Tanh_wrong_granularity_NEG)
SUCCEED();
}

TEST(QuantizedModelVerifierTest, Pack)
{
TEST_WITH_GRAPH(PackTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(PackTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(PackTestGraph, Type::S16, Granularity::ChannelWise);

// Test if Pack's qparam is propagated to the input
{
PackTestGraph g;
g.init();
quantize_and_verify(g.g(), Type::U8, Granularity::ChannelWise);
auto input = loco::must_cast<luci::CircleNode *>(g.pack()->values(0));
auto qp = input->quantparam();
EXPECT_FLOAT_EQ(2.0 / 255.0, qp->scale[0]);
EXPECT_FLOAT_EQ(128, qp->zerop[0]);
}
SUCCEED();
}

TEST(QuantizedModelVerifierTest, Pack_wrong_type_NEG)
{
TEST_WITH_WRONG_TYPE(PackTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
TEST_WITH_WRONG_TYPE(PackTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
TEST_WITH_WRONG_TYPE(PackTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
SUCCEED();
}

TEST(QuantizedModelVerifierTest, Pack_wrong_granularity_NEG)
{
TEST_WITH_WRONG_GRANULARITY(PackTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_WRONG_GRANULARITY(PackTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_WRONG_GRANULARITY(PackTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}

TEST(QuantizedModelVerifierTest, Pad)
{
TEST_WITH_GRAPH(PadTestGraph, Type::U8, Granularity::LayerWise);
Expand Down
10 changes: 10 additions & 0 deletions compiler/luci/pass/src/VerifyQuantizedNodeChannelWiseGranularity.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,16 @@ struct VerifyQuantizedNodeChannelWiseGranularity final : public luci::CircleNode
return true;
}

bool visit(const luci::CirclePack *node)
{
RETURN_FALSE_UNLESS(is_lwq(node))
for (uint32_t i = 0; i < node->values_count(); i++)
{
RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
}
return true;
}

bool visit(const luci::CirclePad *node)
{
RETURN_FALSE_UNLESS(is_lwq(node))
Expand Down
10 changes: 10 additions & 0 deletions compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,16 @@ struct VerifyQuantizedNodeLayerWiseGranularity final : public luci::CircleNodeVi
return true;
}

bool visit(const luci::CirclePack *node)
{
RETURN_FALSE_UNLESS(is_lwq(node))
for (uint32_t i = 0; i < node->values_count(); i++)
{
RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
}
return true;
}

bool visit(const luci::CirclePad *node)
{
RETURN_FALSE_UNLESS(is_lwq(node))
Expand Down
10 changes: 10 additions & 0 deletions compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,16 @@ struct VerifyQuantizedNodeS16Type final : public luci::CircleNodeVisitor<bool>
return true;
}

bool visit(const luci::CirclePack *node)
{
RETURN_FALSE_UNLESS(has_type(node, Type::S16))
for (uint32_t i = 0; i < node->values_count(); i++)
{
RETURN_FALSE_UNLESS(has_type(node->values(i), Type::S16))
}
return true;
}

bool visit(const luci::CirclePad *node)
{
RETURN_FALSE_UNLESS(has_type(node, Type::S16))
Expand Down
10 changes: 10 additions & 0 deletions compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,16 @@ struct VerifyQuantizedNodeU8Type final : public luci::CircleNodeVisitor<bool>
return true;
}

bool visit(const luci::CirclePack *node)
{
RETURN_FALSE_UNLESS(has_type(node, Type::U8))
for (uint32_t i = 0; i < node->values_count(); i++)
{
RETURN_FALSE_UNLESS(has_type(node->values(i), Type::U8))
}
return true;
}

bool visit(const luci::CirclePad *node)
{
RETURN_FALSE_UNLESS(has_type(node, Type::U8))
Expand Down

0 comments on commit d36d6f3

Please sign in to comment.