[quant] Extend the interpreter to support all quantized types for quantize/requantize and rescale nodes
tlepley-cadence authored and bertmaher committed Dec 4, 2018
1 parent 2823858 commit 2dd7391
Showing 6 changed files with 103 additions and 46 deletions.
21 changes: 14 additions & 7 deletions include/glow/Quantization/Base/Base.h
@@ -81,22 +81,29 @@ template <class SrcTy, class DestTy> DestTy clip(SrcTy in) {
   return std::max<SrcTy>(mn, std::min<SrcTy>(mx, in));
 }
 
-/// Converts floating point value to DestTy (int8 or int32) based on the
+/// Converts floating point value to DestTy (quantized type) based on the
 /// quantization parameters \p TQP.
 template <class DestTy = int8_t>
 inline DestTy quantize(float input, const TensorQuantizationParams &TQP) {
   float result = input / TQP.scale + TQP.offset;
   return quantization::clip<int32_t, DestTy>((int32_t)nearbyintf(result));
 }
 
-/// Converts a floating point \p tensor to int8 or int32 based on the
+/// Converts a quantized value (type eTy) to floating point based on the
+/// quantization parameters \p TQP.
+template <class eTy = int8_t>
+inline float dequantize(eTy input, const TensorQuantizationParams &TQP) {
+  return TQP.scale * (input - TQP.offset);
+}
+
+/// Converts a floating point \p tensor to quantized tensor based on the
 /// quantization parameters \p TQP and \p Ty.
 Tensor quantizeTensor(const Tensor &tensor, const TensorQuantizationParams &TQP,
                       ElemKind Ty = ElemKind::Int8QTy);
 
-/// Converts int8 quantized value back to floating point number based on
-/// the quantization parameters \p TQP.
-float dequantize(int8_t input, const TensorQuantizationParams &TQP);
+/// Converts quantized tensor \p tensor to floating point tensor of type \p Ty
+/// floatKind.
+Tensor dequantizeTensor(const Tensor &tensor, ElemKind floatKind);
 
 /// Convert the floating point quantization parameters \p scale and \p offset
 /// into the integer sequence of:
@@ -120,8 +127,8 @@ std::vector<int8_t> createMapping(TypeRef inTy, TypeRef outTy,
 
 /// Row-wise quantize the tensor \p input. The param \p input is a 2D
 /// tensor (i.e. M * N), \p scales and \p offsets are generated by each row of
-/// \p input, \p output is 2D tensor quantized from \p input using \p scales and
-/// \p offsets for each row.
+/// \p input, \p output is 2D tensor quantized from \p input using \p scales
+/// and \p offsets for each row.
 void tensorRowwiseQuantization(const Tensor &input, Tensor &output,
                                Tensor &scales, Tensor &offsets);
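Note: the two scalar conversions above are just value / scale + offset (rounded and clipped to the target range) one way, and scale * (value - offset) back. A minimal standalone sketch of that arithmetic, with TensorQuantizationParams reduced to a hypothetical plain struct and the clip helper inlined (not the actual Glow headers):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

// Stand-in for glow::TensorQuantizationParams.
struct TQParams {
  float scale;
  int32_t offset;
};

// Mirrors quantize(): divide by scale, add offset, round, clip to DestTy.
template <class DestTy = int8_t>
DestTy quantizeValue(float input, const TQParams &TQP) {
  int32_t q = (int32_t)std::nearbyintf(input / TQP.scale + TQP.offset);
  int32_t mn = std::numeric_limits<DestTy>::min();
  int32_t mx = std::numeric_limits<DestTy>::max();
  return (DestTy)std::max(mn, std::min(mx, q));
}

// Mirrors the new templated dequantize(): works for int8/int16/int32.
template <class eTy = int8_t>
float dequantizeValue(eTy input, const TQParams &TQP) {
  return TQP.scale * (input - TQP.offset);
}

int main() {
  TQParams TQP{0.1f, 3};
  int8_t q = quantizeValue<int8_t>(1.25f, TQP); // nearbyint(15.5) = 16
  float f = dequantizeValue<int8_t>(q, TQP);    // 0.1 * (16 - 3) = 1.3
  printf("q=%d f=%f\n", q, f);
  return 0;
}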
7 changes: 5 additions & 2 deletions lib/Backends/Interpreter/InterpreterFunction.h
@@ -20,6 +20,7 @@
 #include "glow/Backends/CompiledFunction.h"
 #include "glow/Base/Tensor.h"
 #include "glow/Graph/Context.h"
+#include "glow/Quantization/Base/Base.h"
 
 #include "llvm/ADT/ArrayRef.h"
 
@@ -205,8 +206,10 @@ class InterpreterFunction final : public CompiledFunction {
   template <typename ElemTy>
   void fwdSparseToDenseInst_FloatImpl(const SparseToDenseInst *I);
 
-  template <typename ElemTy>
-  void fwdDequantizeInst_Impl(const DequantizeInst *I);
+  template <class eTy>
+  void fwdRescaleQuantizedInst_impl(Value *src, Value *dest,
+                                    TensorQuantizationParams &srcQ,
+                                    TensorQuantizationParams &destQ);
   ///@}
 };
55 changes: 32 additions & 23 deletions lib/Backends/Interpreter/InterpreterNodes.cpp
@@ -2172,26 +2172,29 @@ void InterpreterFunction::fwdQuantizeInst(const glow::QuantizeInst *I) {
   destTensor->assign(&qTensor);
 }
 
-template <typename ElemTy>
-void InterpreterFunction::fwdDequantizeInst_Impl(
-    const glow::DequantizeInst *I) {
-  staticAssertFloatingPointType(ElemTy);
-  auto *srcTensor = getTensor(I->getSrc());
-  TensorQuantizationParams params{srcTensor->getType().getScale(),
-                                  srcTensor->getType().getOffset()};
-
-  auto destHandle = getWeightHandle<ElemTy>(I->getDest());
-  auto srcHandle = srcTensor->getHandle<int8_t>();
-  for (size_t i = 0, e = destHandle.size(); i < e; ++i) {
-    destHandle.raw(i) = quantization::dequantize(srcHandle.raw(i), params);
-  }
-}
-
 /// Dequantize integer tensor. Scale and Offset are based
 /// on the source tensor type.
 void InterpreterFunction::fwdDequantizeInst(const glow::DequantizeInst *I) {
-  dispatchFloatingPointImpl(fwdDequantizeInst_Impl,
-                            I->getDest()->getElementType(), I);
+  auto *srcTensor = getTensor(I->getSrc());
+  auto *destTensor = getTensor(I->getDest());
+  auto destTy = destTensor->getType();
+  Tensor fTensor =
+      quantization::dequantizeTensor(*srcTensor, destTy.getElementType());
+  destTensor->assign(&fTensor);
 }
 
+template <class eTy>
+void InterpreterFunction::fwdRescaleQuantizedInst_impl(
+    Value *src, Value *dest, TensorQuantizationParams &srcQ,
+    TensorQuantizationParams &destQ) {
+
+  auto srcH = getWeightHandle<eTy>(src);
+  auto destH = getWeightHandle<eTy>(dest);
+
+  for (size_t i = 0, e = destH.size(); i < e; ++i) {
+    float val = quantization::dequantize(srcH.raw(i), srcQ);
+    destH.raw(i) = quantization::quantize(val, destQ);
+  }
+}
+
 void InterpreterFunction::fwdRescaleQuantizedInst(
@@ -2204,12 +2207,18 @@ void InterpreterFunction::fwdRescaleQuantizedInst(
   TensorQuantizationParams srcQ{srcTy->getScale(), srcTy->getOffset()};
   TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()};
 
-  auto srcH = getWeightHandle<int8_t>(src);
-  auto destH = getWeightHandle<int8_t>(dest);
-
-  for (size_t i = 0, e = destH.size(); i < e; ++i) {
-    float val = quantization::dequantize(srcH.raw(i), srcQ);
-    destH.raw(i) = quantization::quantize(val, destQ);
+  switch (destTy->getElementType()) {
+  case ElemKind::Int8QTy:
+    fwdRescaleQuantizedInst_impl<int8_t>(src, dest, srcQ, destQ);
+    break;
+  case ElemKind::Int16QTy:
+    fwdRescaleQuantizedInst_impl<int16_t>(src, dest, srcQ, destQ);
+    break;
+  case ElemKind::Int32QTy:
+    fwdRescaleQuantizedInst_impl<int32_t>(src, dest, srcQ, destQ);
+    break;
+  default:
+    llvm_unreachable("Quantized type not supported");
   }
 }
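Note: numerically, a rescale is a dequantize with the source parameters followed by a quantize with the destination parameters, and after this change the same loop is instantiated for int8_t, int16_t, and int32_t. A standalone sketch of one element, using the same stand-in struct as above rather than the interpreter's weight handles:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

struct TQParams {
  float scale;
  int32_t offset;
};

// One element of the rescale loop: dequantize with srcQ, requantize
// with destQ, clipping to the (shared) element type eTy.
template <class eTy>
eTy rescaleValue(eTy v, const TQParams &srcQ, const TQParams &destQ) {
  float val = srcQ.scale * (v - srcQ.offset);
  int32_t q = (int32_t)std::nearbyintf(val / destQ.scale + destQ.offset);
  int32_t mn = std::numeric_limits<eTy>::min();
  int32_t mx = std::numeric_limits<eTy>::max();
  return (eTy)std::max(mn, std::min(mx, q));
}

int main() {
  TQParams srcQ{0.5f, 0};   // raw value 10 encodes 5.0
  TQParams destQ{0.25f, -2};
  // 5.0 / 0.25 + (-2) = 18; the same call works for int8_t and int32_t.
  printf("%d\n", (int)rescaleValue<int16_t>((int16_t)10, srcQ, destQ));
  return 0;
}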
10 changes: 0 additions & 10 deletions lib/IR/Instrs.cpp
@@ -90,13 +90,3 @@ void InsertTensorInst::verify() const {
   assert(getAxis() >= 0 && getAxis() < getDest()->dims().size() &&
          "Axis must fit inside Dest dims.");
 }
-
-void QuantizeInst::verify() const {
-  assert((getDest()->getElementType() == ElemKind::Int8QTy ||
-          getDest()->getElementType() == ElemKind::Int32QTy) &&
-         "Invalid type");
-  assert((getSrc()->getElementType() == ElemKind::FloatTy ||
-          getSrc()->getElementType() == ElemKind::Float16Ty) &&
-         "Invalid type");
-  assert(getSrc()->dims() == getDest()->dims() && "Invalid shape");
-}
51 changes: 48 additions & 3 deletions lib/Quantization/Base/Base.cpp
@@ -55,14 +55,59 @@ Tensor quantizeTensor(const Tensor &tensor, const TensorQuantizationParams &TQP,
   assert(tensor.getType().isFPType() && "Type not supported yet");
   if (Ty == ElemKind::Int8QTy) {
     quantizeTensorUtil<int8_t>(&tmp, tensor);
-  } else {
+  } else if (Ty == ElemKind::Int16QTy) {
+    quantizeTensorUtil<int16_t>(&tmp, tensor);
+  } else if (Ty == ElemKind::Int32QTy) {
     quantizeTensorUtil<int32_t>(&tmp, tensor);
+  } else {
+    llvm_unreachable("Quantized type not supported");
   }
   return tmp;
 }
 
-float dequantize(int8_t input, const TensorQuantizationParams &TQP) {
-  return TQP.scale * (input - TQP.offset);
+template <class eTy = int8_t>
+static void dequantizeTensorUtil(Tensor *dest, const Tensor &src) {
+  TensorQuantizationParams TQP{src.getType().getScale(),
+                               src.getType().getOffset()};
+  auto srcHandle = src.getHandle<eTy>();
+  switch (dest->getElementType()) {
+  case ElemKind::FloatTy: {
+    auto destH = dest->getHandle<float>();
+    for (size_t i = 0, e = destH.size(); i < e; ++i) {
+      destH.raw(i) = quantization::dequantize<eTy>(
+          static_cast<eTy>(srcHandle.raw(i)), TQP);
+    }
+    break;
+  }
+  case ElemKind::Float16Ty: {
+    auto destH = dest->getHandle<float16>();
+    for (size_t i = 0, e = destH.size(); i < e; ++i) {
+      destH.raw(i) = quantization::dequantize<eTy>(
+          static_cast<eTy>(srcHandle.raw(i)), TQP);
+    }
+    break;
+  }
+  default:
+    llvm_unreachable("Cannot dequantize to the given type");
+  }
 }
 
+Tensor dequantizeTensor(const Tensor &tensor, ElemKind floatKind) {
+  assert(((floatKind == ElemKind::FloatTy) ||
+          (floatKind == ElemKind::Float16Ty)) &&
+         "Non supported output floating point type");
+  Tensor tmp(floatKind, tensor.dims());
+  auto Ty = tensor.getType().getElementType();
+  if (Ty == ElemKind::Int8QTy) {
+    dequantizeTensorUtil<int8_t>(&tmp, tensor);
+  } else if (Ty == ElemKind::Int16QTy) {
+    dequantizeTensorUtil<int16_t>(&tmp, tensor);
+  } else if (Ty == ElemKind::Int32QTy) {
+    dequantizeTensorUtil<int32_t>(&tmp, tensor);
+  } else {
+    llvm_unreachable("Input quantized type not supported");
+  }
+  return tmp;
+}
 
 QuantizationTransform32To8 quantizeScaleOffset32To8(float scale,
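Note: a usage sketch of the extended tensor-level API. The Tensor constructor and the initializer-list handle assignment follow the idioms used throughout the Glow tests; treat the exact calls as assumptions rather than an excerpt from this commit:

#include "glow/Base/Tensor.h"
#include "glow/Quantization/Base/Base.h"

using namespace glow;

void roundTrip() {
  // A small float tensor to quantize.
  Tensor F(ElemKind::FloatTy, {4});
  F.getHandle<float>() = {0.0f, 0.25f, 0.5f, 1.0f};

  TensorQuantizationParams TQP{/* scale */ 0.25f, /* offset */ 0};

  // After this change, Int8QTy, Int16QTy, and Int32QTy are all accepted.
  Tensor Q = quantization::quantizeTensor(F, TQP, ElemKind::Int16QTy);

  // And back to floating point; FloatTy and Float16Ty are supported.
  Tensor R = quantization::dequantizeTensor(Q, ElemKind::FloatTy);
}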
5 changes: 4 additions & 1 deletion tools/ClassGen/InstrGen.cpp
@@ -480,14 +480,17 @@ int main(int argc, char **argv) {
   BB.newInstr("Quantize")
       .addOperand("Dest", OperandKind::Out)
       .addOperand("Src", OperandKind::In)
+      .autoVerify(VerifyKind::TypeCheck, {"Src", "isFPType()"})
+      .autoVerify(VerifyKind::TypeCheck, {"Dest", "isQuantizedType()"})
+      .autoVerify(VerifyKind::SameShape, {"Dest", "Src"})
       .dataParallel()
       .autoIRGen();
 
   BB.newInstr("Dequantize")
       .addOperand("Dest", OperandKind::Out)
       .addOperand("Src", OperandKind::In)
       .autoVerify(VerifyKind::TypeCheck, {"Dest", "isFPType()"})
-      .autoVerify(VerifyKind::SameElementType, {"Src", "ElemKind::Int8QTy"})
+      .autoVerify(VerifyKind::TypeCheck, {"Src", "isQuantizedType()"})
       .autoVerify(VerifyKind::SameShape, {"Dest", "Src"})
       .dataParallel()
       .autoIRGen();
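Note: with these autoVerify directives, InstrGen emits the type and shape checks that the hand-written QuantizeInst::verify in lib/IR/Instrs.cpp used to perform (hence its deletion above). A hypothetical reconstruction of roughly what the generated verifiers assert; the real bodies are produced at build time and may differ:

// Hypothetical reconstruction; the actual code is generated by InstrGen.
void QuantizeInst::verify() const {
  assert(getSrc()->getType()->isFPType() && "Invalid type");
  assert(getDest()->getType()->isQuantizedType() && "Invalid type");
  assert(getSrc()->dims() == getDest()->dims() && "Invalid shape");
}

void DequantizeInst::verify() const {
  assert(getDest()->getType()->isFPType() && "Invalid type");
  assert(getSrc()->getType()->isQuantizedType() && "Invalid type");
  assert(getSrc()->dims() == getDest()->dims() && "Invalid shape");
}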
