Skip to content

Commit

Permalink
[Quantization] Dequantize(rescale) -> Dequantize() (pytorch#2077)
Browse files Browse the repository at this point in the history
  • Loading branch information
tlepley-cadence authored and rdzhabarov committed Nov 27, 2018
1 parent 843ae43 commit c34d599
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 11 deletions.
11 changes: 11 additions & 0 deletions lib/Optimizer/GraphOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1901,6 +1901,17 @@ static void optimizeQuantization(Function *F) {
DQ->getResult().replaceAllUsesOfWith(Q->getInput());
continue;
}
// Fold the rescale into the following Dequantize.
// Dequantize(rescale) -> Dequantize()
if (auto *RS = dyn_cast<RescaleQuantizedNode>(DQ->getInput())) {
auto *newRS = F->createDequantize(DQ->getName(), RS->getInput());
DQ->getResult().replaceAllUsesOfWith(newRS);

// We may be able to optimize this rescale node. Remember to visit this
// new node and try to optimize it later.
worklist.push_back(newRS);
continue;
}
}

if (auto *RS = dyn_cast<RescaleQuantizedNode>(node)) {
Expand Down
25 changes: 25 additions & 0 deletions tests/unittests/graphOptzTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1554,6 +1554,31 @@ TEST_F(GraphOptz, foldQuantizeIntoVarMultipleUsages) {
}
}

/// Check that rescale gets correctly merged into a following dequantize node
TEST_F(GraphOptz, mergeRescaleIntoDequantize) {
// Check that we are combining quantization-dequantization pairs.
auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11,
"input", true);
auto *qType = mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.03f, 5);
auto *R = F_->createRescaleQuantized("rescale", input, qType);
auto *D = F_->createDequantize("dequantize", R);
F_->createSave("ret", D);

EXPECT_EQ(F_->getNodes().size(), 3);
::glow::optimize(F_, CompilationMode::Infer);

// Only 2 nodes should remain (Dequantize -> Save)
EXPECT_EQ(F_->getNodes().size(), 2);
// Check the graph structure
auto *SN = F_->getNodeByName("ret");
EXPECT_NE(nullptr, SN);
auto *S = llvm::dyn_cast<SaveNode>(SN);
EXPECT_NE(nullptr, S);
auto *newDN = S->getInput().getNode();
EXPECT_NE(nullptr, newDN);
EXPECT_NE(nullptr, llvm::dyn_cast<DequantizeNode>(newDN));
}

TEST_F(GraphOptz, quantizeToRescale) {
// Check that we are combining quantization-dequantization pairs.
auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11,
Expand Down
18 changes: 7 additions & 11 deletions tests/unittests/quantizationTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1155,11 +1155,9 @@ TEST(Quantization, quantizeSlice) {
// Verify that the slice is rescaled after being quantized.
// The reason we need a rescale is because slicing doesn't perform rescaling
// by itself.
auto *RN = llvm::dyn_cast<RescaleQuantizedNode>(DN->getInput());
ASSERT_TRUE(RN);
EXPECT_EQ(RN->getResult().getType()->getOffset(), -128);
EXPECT_EQ(RN->getResult().getType()->getScale(), 0.2f);
auto *qslice = llvm::dyn_cast<SliceNode>(RN->getInput());
// Note: after optimization, the RescaleQuantized node created for the Slice
// gets merged with the dequantize node.
auto *qslice = llvm::dyn_cast<SliceNode>(DN->getInput());
ASSERT_TRUE(qslice);
ASSERT_TRUE(qslice->getResult().getType()->isQuantizedType());
EXPECT_EQ(qslice->getResult().getType()->getOffset(), 0);
Expand Down Expand Up @@ -1223,14 +1221,12 @@ TEST(Quantization, quantizeReshape) {
auto *DN = llvm::dyn_cast<DequantizeNode>(SN->getInput());
ASSERT_TRUE(DN);

// Verify that the slice is rescaled after being quantized.
// Verify that the reshape is rescaled after being quantized.
// The reason we need a rescale is because reshaping doesn't perform
// rescaling by itself.
auto *RN = llvm::dyn_cast<RescaleQuantizedNode>(DN->getInput());
ASSERT_TRUE(RN);
EXPECT_EQ(RN->getResult().getType()->getOffset(), -128);
EXPECT_EQ(RN->getResult().getType()->getScale(), 0.2f);
auto *qreshape = llvm::dyn_cast<ReshapeNode>(RN->getInput());
// Note: after optimization, the RescaleQuantized node created for the
// Reshape gets merged with the dequantize node.
auto *qreshape = llvm::dyn_cast<ReshapeNode>(DN->getInput());
ASSERT_TRUE(qreshape);
ASSERT_TRUE(qreshape->getResult().getType()->isQuantizedType());
EXPECT_EQ(qreshape->getResult().getType()->getOffset(), 0);
Expand Down

0 comments on commit c34d599

Please sign in to comment.