Skip to content

Commit

Permalink
Provide Backend instead of EE to Quantizer
Browse files Browse the repository at this point in the history
  • Loading branch information
jfix71 committed Apr 2, 2019
1 parent 46bd1bc commit 5cbc7d1
Show file tree
Hide file tree
Showing 11 changed files with 71 additions and 63 deletions.
4 changes: 2 additions & 2 deletions examples/fr2en.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ struct Model {

// Quantize the graph based on the captured profile.
auto *Q = glow::quantization::quantizeFunction(
EE_, quantization::Schema::Asymmetric, quantizationInfos,
ElemKind::Int8QTy, F_, loweredMap_);
*EE_.getBackend(), quantization::Schema::Asymmetric,
quantizationInfos, ElemKind::Int8QTy, F_, loweredMap_);

// Erase the original function so that the redundant variables that are
// only referenced by the original function will be removed.
Expand Down
4 changes: 2 additions & 2 deletions include/glow/Quantization/Quantization.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

namespace glow {

class ExecutionEngine;
class Backend;

/// Tensor quantization parameters for a given node.
struct NodeQuantizationInfo {
Expand Down Expand Up @@ -103,7 +103,7 @@ std::vector<NodeQuantizationInfo> generateNodeQuantizationInfos(
/// nodes will be converted to RowwiseQuantizedFullyConnected. \returns a new
/// quantized function.
Function *quantizeFunction(
const ExecutionEngine &EE, quantization::Schema schema,
const Backend &B, quantization::Schema schema,
llvm::ArrayRef<NodeQuantizationInfo> quantizationInfos,
ElemKind quantizationPrecision, Function *F,
const LoweredInfoMap &loweredMap = {}, llvm::StringRef newFuncName = "",
Expand Down
4 changes: 2 additions & 2 deletions lib/Onnxifi/InlineOnnxifi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ InlineGraph::initGraph(const void *onnxModel, size_t onnxModelSize,
std::string oldName = function_->getName();
function_->setName("old");
auto *Q = quantization::quantizeFunction(
executionEngine_, quantization::Schema::Symmetric, QI,
*executionEngine_.getBackend(), quantization::Schema::Symmetric, QI,
ElemKind::Int8QTy, function_, loweredMap_, oldName, {}, false);
Q->getParent()->eraseFunction(function_);
function_ = Q;
Expand Down Expand Up @@ -105,4 +105,4 @@ InlineGraph::run(std::unique_ptr<ExecutionContext> ctx, EventPtr outputEvent,
}

} // namespace onnxifi
} // namespace glow
} // namespace glow
2 changes: 1 addition & 1 deletion lib/Quantization/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ target_link_libraries(Quantization
PRIVATE
Converter
Graph
ExecutionEngine
Backend
QuantizationBase
LLVMSupport)
20 changes: 9 additions & 11 deletions lib/Quantization/Quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

#include "glow/Quantization/Quantization.h"

#include "glow/Backends/Backend.h"
#include "glow/Converter/FunctionConverter.h"
#include "glow/ExecutionEngine/ExecutionEngine.h"

#include <cmath>
#include <unordered_set>
Expand Down Expand Up @@ -198,7 +198,7 @@ class FunctionQuantizer : public FunctionConverter {
}

// Only convert the node if the backend supports the newly converted node.
return EE_.isOpSupported(NodeInfo(node.getKind(), inputTypes, outputTypes));
return B_.isOpSupported(NodeInfo(node.getKind(), inputTypes, outputTypes));
}

/// Helper that \returns whether quantization parameters exist
Expand Down Expand Up @@ -421,9 +421,8 @@ class FunctionQuantizer : public FunctionConverter {
private:
/// Shortcut to the module of function_.
Module &mod_;
/// Execution engine used to check if a quantized operator is
/// supported.
const ExecutionEngine &EE_;
/// Backend used to check if a quantized operator is supported.
const Backend &B_;
/// Quantization schema.
quantization::Schema schema_;
/// Quantization precision.
Expand Down Expand Up @@ -467,15 +466,14 @@ class FunctionQuantizer : public FunctionConverter {
/// Creates a function quantizer for \p F using the quantization
/// parameters defined by \p quantizationInfos and target quantization
/// precision defined by \p quantizationPrecision.
/// \p EE and \p doNotQuantizeKinds are used to check which
/// \p B and \p doNotQuantizeKinds are used to check which
/// nodes shouldn't be converted.
FunctionQuantizer(Function &F, const ExecutionEngine &EE,
quantization::Schema schema,
FunctionQuantizer(Function &F, const Backend &B, quantization::Schema schema,
llvm::ArrayRef<NodeQuantizationInfo> quantizationInfos,
ElemKind quantizationPrecision,
const KindSet &doNotQuantizeKinds,
const LoweredInfoMap &loweredMap)
: FunctionConverter(F), mod_(*F.getParent()), EE_(EE), schema_(schema),
: FunctionConverter(F), mod_(*F.getParent()), B_(B), schema_(schema),
quantizationPrecision_(quantizationPrecision),
doNotQuantizeKinds_(doNotQuantizeKinds), loweredMap_(loweredMap) {
// Build a mapping between node name and TensorQuantizatonParams.
Expand Down Expand Up @@ -697,7 +695,7 @@ generateNodeQuantizationInfos(PlaceholderBindings &bindings, const Function *F,
}

Function *
quantizeFunction(const ExecutionEngine &EE, quantization::Schema schema,
quantizeFunction(const Backend &B, quantization::Schema schema,
llvm::ArrayRef<NodeQuantizationInfo> quantizationInfos,
ElemKind quantizationPrecision, Function *F,
const LoweredInfoMap &loweredMap, llvm::StringRef newFuncName,
Expand All @@ -713,7 +711,7 @@ quantizeFunction(const ExecutionEngine &EE, quantization::Schema schema,

Function *G = F->clone(newFuncName);

FunctionQuantizer quantizer(*G, EE, schema, quantizationInfos,
FunctionQuantizer quantizer(*G, B, schema, quantizationInfos,
quantizationPrecision, doNotQuantizeKinds,
loweredMap);
quantizer.convert();
Expand Down
12 changes: 6 additions & 6 deletions tests/unittests/BackendTestUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,17 @@ static void profileAndQuantize(PlaceholderBindings &Ibindings,
// Lower only as the backends prefer for actually quantizing.
LoweredInfoMap loweredMapForQuant;
lower(IF, &loweredMapForQuant, IEE.getBackend());
IF = quantization::quantizeFunction(IEE, schema, QI, interpElemKind, IF,
loweredMapForQuant, "quant", {},
enableRowwiseQuantization);
IF = quantization::quantizeFunction(*IEE.getBackend(), schema, QI,
interpElemKind, IF, loweredMapForQuant,
"quant", {}, enableRowwiseQuantization);
}
if (isQuantizedElemKind(backendElemKind)) {
// Lower only as the backends prefer for actually quantizing.
LoweredInfoMap loweredMapForQuant;
lower(BF, &loweredMapForQuant, BEE.getBackend());
BF = quantization::quantizeFunction(BEE, schema, QI, backendElemKind, BF,
loweredMapForQuant, "quant", {},
enableRowwiseQuantization);
BF = quantization::quantizeFunction(*BEE.getBackend(), schema, QI,
backendElemKind, BF, loweredMapForQuant,
"quant", {}, enableRowwiseQuantization);
}
}

Expand Down
8 changes: 4 additions & 4 deletions tests/unittests/GradCheckTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ void performGradCheck(ExecutionEngine &EE, PlaceholderBindings &bindings,
float delta, float allowedError) {
TrainingConfig TC;

auto &F = *EE.getModule().getFunction("main");
auto *F = EE.getModule().getFunction("main");

// Allocate result, inputVar and expVar.
auto resultTensor = bindings.allocate(result->getPlaceholder());
Expand All @@ -103,7 +103,7 @@ void performGradCheck(ExecutionEngine &EE, PlaceholderBindings &bindings,
size_t sampleCounter = 0;

// Create a function that trains the network.
Function *TF = glow::differentiate(&F, TC);
Function *TF = glow::differentiate(F, TC);
EE.compile(CompilationMode::Train, TF);

// The network might have variables, other than inputVar and expVar.
Expand All @@ -114,7 +114,7 @@ void performGradCheck(ExecutionEngine &EE, PlaceholderBindings &bindings,
// Create a version of the network that records the gradients to some side
// table instead of updating them.
VariableGradientsList varGrads;
Function *recordNet = glow::differentiate(&F, TC, "record", &varGrads);
Function *recordNet = glow::differentiate(F, TC, "record", &varGrads);
allocateGrads(bindings, varGrads);
EE.compile(CompilationMode::Train, recordNet);

Expand All @@ -128,7 +128,7 @@ void performGradCheck(ExecutionEngine &EE, PlaceholderBindings &bindings,
{inputs, outputs});

// Compile the original network in inference mode.
EE.compile(CompilationMode::Infer, &F);
EE.compile(CompilationMode::Infer, F);

auto analyticalGradsH = gradVarTensor->getHandle();
auto inputsH = inputs->getHandle<>();
Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/HyphenTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,8 @@ struct HyphenNetwork {
TrainingConfig &TC) {
// Compilation is destructive because of target-specific lowering.
// Compile a clone of the inference function.
EE.compile(CompilationMode::Infer, infer_->clone(name));
auto *CF = infer_->clone(name);
EE.compile(CompilationMode::Infer, CF);

auto batchSize = TC.batchSize;
auto numSamples = inputs.dims()[0];
Expand Down
12 changes: 6 additions & 6 deletions tests/unittests/MLTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1077,9 +1077,9 @@ TEST_P(InterpreterAndCPU, convNetForImageRecognition) {
// Build the new quantized graph.
LoweredInfoMap loweredMapForQuant;
lower(F, &loweredMapForQuant, EE.getBackend());
Function *QP =
quantization::quantizeFunction(EE, quantization::Schema::Asymmetric, QI,
ElemKind::Int8QTy, F, loweredMapForQuant);
Function *QP = quantization::quantizeFunction(
*EE.getBackend(), quantization::Schema::Asymmetric, QI, ElemKind::Int8QTy,
F, loweredMapForQuant);

EE.compile(CompilationMode::Infer, QP);

Expand Down Expand Up @@ -1197,9 +1197,9 @@ TEST_P(InterpreterAndCPU, testFindPixelRegression) {
// Build the new quantized graph.
LoweredInfoMap loweredMapForQuant;
lower(F, &loweredMapForQuant, EE.getBackend());
Function *QP =
quantization::quantizeFunction(EE, quantization::Schema::Asymmetric, QI,
ElemKind::Int8QTy, F, loweredMapForQuant);
Function *QP = quantization::quantizeFunction(
*EE.getBackend(), quantization::Schema::Asymmetric, QI, ElemKind::Int8QTy,
F, loweredMapForQuant);

EE.compile(CompilationMode::Infer, QP);

Expand Down
60 changes: 34 additions & 26 deletions tests/unittests/QuantizationTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,8 @@ void quantizeGraph(ElemKind quantizationPrecision) {
{NodeQuantizationInfo::generateNodeOutputName(FC->getName()), {0.6f, 0}},
};

F = quantization::quantizeFunction(EE, quantization::Schema::Asymmetric, QI,
F = quantization::quantizeFunction(*EE.getBackend(),
quantization::Schema::Asymmetric, QI,
quantizationPrecision, F);

// Make sure that graph can be compiled.
Expand Down Expand Up @@ -304,7 +305,8 @@ TEST(Quantization, enableRowwiseQuantizedFullyConnected) {
{NodeQuantizationInfo::generateNodeOutputName(FC->getName()), {0.6f, 0}},
};

F = quantization::quantizeFunction(EE, quantization::Schema::Asymmetric, QI,
F = quantization::quantizeFunction(*EE.getBackend(),
quantization::Schema::Asymmetric, QI,
ElemKind::Int8QTy, F, {}, "", {}, true);

// Check the graph structure after quantization.
Expand Down Expand Up @@ -373,7 +375,8 @@ TEST(Quantization, enableRowwiseQuantizedSLWS) {
{0.4f, 0}},
};

F = quantization::quantizeFunction(EE, quantization::Schema::Asymmetric, QI,
F = quantization::quantizeFunction(*EE.getBackend(),
quantization::Schema::Asymmetric, QI,
ElemKind::Int8QTy, F, {}, "", {},
/* enableRowwiseQuantization */ true);

Expand Down Expand Up @@ -407,7 +410,8 @@ TEST(Quantization, quantizeReLU) {
{0.2f, 0}},
{NodeQuantizationInfo::generateNodeOutputName(relu->getName()),
{0.2f, -128}}};
F = quantization::quantizeFunction(EE, quantization::Schema::Asymmetric, QI,
F = quantization::quantizeFunction(*EE.getBackend(),
quantization::Schema::Asymmetric, QI,
ElemKind::Int8QTy, F);
EE.compile(CompilationMode::Infer, F);

Expand Down Expand Up @@ -520,7 +524,7 @@ void testQuantizationEnd2End(ExecutionEngine &profileEE,
LoweredInfoMap loweredMapForQuant;
lower(F2, &loweredMapForQuant, backendSpecificEE.getBackend());
F2 = quantization::quantizeFunction(
backendSpecificEE, quantization::Schema::Asymmetric, QI,
*backendSpecificEE.getBackend(), quantization::Schema::Asymmetric, QI,
quantizationPrecision, F2, loweredMapForQuant);
backendSpecificEE.compile(CompilationMode::Infer, F2);
backendSpecificEE.run(bindings);
Expand Down Expand Up @@ -674,7 +678,7 @@ TEST_P(Operator, end2endGRU) {
LoweredInfoMap loweredMapForQuant;
lower(F2, &loweredMapForQuant, backendSpecificEE.getBackend());
F2 = quantization::quantizeFunction(
backendSpecificEE, quantization::Schema::Asymmetric, QI,
*backendSpecificEE.getBackend(), quantization::Schema::Asymmetric, QI,
ElemKind::Int8QTy, F2, loweredMapForQuant);
backendSpecificEE.compile(CompilationMode::Infer, F2);
backendSpecificEE.run(bindings);
Expand Down Expand Up @@ -989,7 +993,8 @@ TEST(Quantization, quantizeSoftmaxAndLRN) {
{NodeQuantizationInfo::generateNodeOutputName(SM->getName()), {0.4f, 0}},
{NodeQuantizationInfo::generateNodeOutputName(SN->getName()), {0.4f, 0}}};

F = quantization::quantizeFunction(EE, quantization::Schema::Asymmetric, QI,
F = quantization::quantizeFunction(*EE.getBackend(),
quantization::Schema::Asymmetric, QI,
ElemKind::Int8QTy, F);

auto qLRNIt = std::find_if(
Expand Down Expand Up @@ -1042,7 +1047,8 @@ TEST(Quantization, quantizeSelect) {
{NodeQuantizationInfo::generateNodeOutputName(select->getName()),
selectQP}};

F = quantization::quantizeFunction(EE, quantization::Schema::Asymmetric, QI,
F = quantization::quantizeFunction(*EE.getBackend(),
quantization::Schema::Asymmetric, QI,
ElemKind::Int8QTy, F);

auto it = std::find_if(
Expand Down Expand Up @@ -1087,7 +1093,8 @@ TEST(Quantization, quantizeAvgPool) {
{NodeQuantizationInfo::generateNodeOutputName(s->getName()), {0.4f, 0}},
};

F = quantization::quantizeFunction(EE, quantization::Schema::Asymmetric, QI,
F = quantization::quantizeFunction(*EE.getBackend(),
quantization::Schema::Asymmetric, QI,
ElemKind::Int8QTy, F);

auto qPool = std::find_if(F->getNodes().begin(), F->getNodes().end(),
Expand Down Expand Up @@ -1138,8 +1145,8 @@ TEST(Quantization, quantizeGraphPartially) {
doNotQuantize.insert(Kinded::Kind::TanhNodeKind);

auto *QF = quantization::quantizeFunction(
EE, quantization::Schema::Asymmetric, QI, ElemKind::Int8QTy, F, {},
"_quantized", doNotQuantize);
*EE.getBackend(), quantization::Schema::Asymmetric, QI, ElemKind::Int8QTy,
F, {}, "_quantized", doNotQuantize);
QF->getParent()->eraseFunction(F);
F = QF;

Expand Down Expand Up @@ -1220,8 +1227,8 @@ TEST(Quantization, quantizeGraphPartiallyMultipleNodes) {
doNotQuantize.insert(Kinded::Kind::TanhNodeKind);

auto *QF = quantization::quantizeFunction(
EE, quantization::Schema::Asymmetric, QI, ElemKind::Int8QTy, F, {},
"_quantized", doNotQuantize);
*EE.getBackend(), quantization::Schema::Asymmetric, QI, ElemKind::Int8QTy,
F, {}, "_quantized", doNotQuantize);
QF->getParent()->eraseFunction(F);
F = QF;

Expand Down Expand Up @@ -1312,8 +1319,8 @@ TEST(Quantization, quantizeGraphPartiallyMultipleKinds) {
doNotQuantize.insert(Kinded::Kind::AddNodeKind);

auto *QF = quantization::quantizeFunction(
EE, quantization::Schema::Asymmetric, QI, ElemKind::Int8QTy, F, {},
"_quantized", doNotQuantize);
*EE.getBackend(), quantization::Schema::Asymmetric, QI, ElemKind::Int8QTy,
F, {}, "_quantized", doNotQuantize);
QF->getParent()->eraseFunction(F);
F = QF;

Expand Down Expand Up @@ -1394,9 +1401,9 @@ TEST(Quantization, quantizeFunctionConvertConstant) {
{NodeQuantizationInfo::generateNodeOutputName(MMN->getName()), {0.6f, 0}},
};

auto *QF =
quantization::quantizeFunction(EE, quantization::Schema::Asymmetric, QI,
ElemKind::Int8QTy, F, {}, "_quantized");
auto *QF = quantization::quantizeFunction(
*EE.getBackend(), quantization::Schema::Asymmetric, QI, ElemKind::Int8QTy,
F, {}, "_quantized");
QF->getParent()->eraseFunction(F);
F = QF;

Expand Down Expand Up @@ -1459,9 +1466,9 @@ TEST(Quantization, quantizeSlice) {
{0.4f, 0}},
};

auto *QF =
quantization::quantizeFunction(EE, quantization::Schema::Asymmetric, QI,
ElemKind::Int8QTy, F, {}, "_quantized");
auto *QF = quantization::quantizeFunction(
*EE.getBackend(), quantization::Schema::Asymmetric, QI, ElemKind::Int8QTy,
F, {}, "_quantized");
QF->getParent()->eraseFunction(F);
F = QF;

Expand Down Expand Up @@ -1530,9 +1537,9 @@ TEST(Quantization, quantizeReshape) {
{0.4f, 0}},
};

auto *QF =
quantization::quantizeFunction(EE, quantization::Schema::Asymmetric, QI,
ElemKind::Int8QTy, F, {}, "_quantized");
auto *QF = quantization::quantizeFunction(
*EE.getBackend(), quantization::Schema::Asymmetric, QI, ElemKind::Int8QTy,
F, {}, "_quantized");
QF->getParent()->eraseFunction(F);
F = QF;

Expand Down Expand Up @@ -1736,8 +1743,9 @@ static void testProfileQuantizationOfFC(bool expectLoweredFC,
// Quantize the function given the current backend we're testing along with
// the quantization infos gathered.
backendF = quantization::quantizeFunction(
backendEE, quantization::Schema::Asymmetric, QI, ElemKind::Int8QTy,
backendF, loweredMapForQuant, "quant", {}, rowwiseQuantizeFC);
*backendEE.getBackend(), quantization::Schema::Asymmetric, QI,
ElemKind::Int8QTy, backendF, loweredMapForQuant, "quant", {},
rowwiseQuantizeFC);

// Compile the graph to remove dead code and optimize away unnecessary
// quantize nodes.
Expand Down
Loading

0 comments on commit 5cbc7d1

Please sign in to comment.