Skip to content

Commit

Permalink
[quant] Make quantization related nodes more generic
Browse files Browse the repository at this point in the history
  • Loading branch information
tlepley-cadence authored and bertmaher committed Dec 4, 2018
1 parent 31a09e7 commit 2823858
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
13 changes: 6 additions & 7 deletions lib/Graph/Graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1468,9 +1468,7 @@ BatchOneHotNode *Function::createBatchOneHot(llvm::StringRef name,
QuantizeNode *Function::createQuantize(llvm::StringRef name, NodeValue input,
TypeRef outTy) {
assert(input.getType()->isFPType() && "Input must be a floating type");
assert((outTy->getElementType() == ElemKind::Int8QTy ||
outTy->getElementType() == ElemKind::Int32QTy) &&
"Output must be a quantized type");
assert(outTy->isQuantizedType() && "Output must be a quantized type");
assert(input.dims().equals(outTy->dims()) &&
"Different dimensions for input and output");

Expand All @@ -1480,14 +1478,16 @@ QuantizeNode *Function::createQuantize(llvm::StringRef name, NodeValue input,

DequantizeNode *Function::createDequantize(llvm::StringRef name,
NodeValue input) {
assert(input.getType()->isQuantizedType() &&
"Input must be a quantized type");
TypeRef outTy =
getParent()->uniqueType(Type(ElemKind::FloatTy, input.dims()));
return createDequantize(name, input, outTy);
}

DequantizeNode *Function::createDequantize(llvm::StringRef name,
NodeValue input, TypeRef outTy) {
assert(input.getElementType() == ElemKind::Int8QTy &&
assert(input.getType()->isQuantizedType() &&
"Input must be a quantized type");
assert(outTy->isFPType() && "Output should be an FP type");
return addNode(new DequantizeNode(name, outTy, input));
Expand All @@ -1496,10 +1496,9 @@ DequantizeNode *Function::createDequantize(llvm::StringRef name,
RescaleQuantizedNode *Function::createRescaleQuantized(llvm::StringRef name,
NodeValue input,
TypeRef outTy) {
assert(input.getElementType() == ElemKind::Int8QTy &&
assert(input.getType()->isQuantizedType() &&
"Input must be a quantized type");
assert(outTy->getElementType() == ElemKind::Int8QTy &&
"Output must be a quantized type");
assert(outTy->isQuantizedType() && "Output must be a quantized type");
assert(input.dims().equals(outTy->dims()) &&
"Different dimensions for input and output");

Expand Down
15 changes: 8 additions & 7 deletions lib/Graph/Nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -762,20 +762,21 @@ bool QuantizeNode::verify() const {
}

bool DequantizeNode::verify() const {
// Dest must be an FP type.
bool isValid = expectCompareTrue(
"Dest must be an FP type", getResult().getType()->isFPType(), true, this);
// Src must be quantized.
isValid &= checkType(getInput(), ElemKind::Int8QTy, this);
isValid &=
expectCompareTrue("Src must be quantized",
getInput().getType()->isQuantizedType(), true, this);
isValid &= checkSameShape(getResult(), getInput(), this);
return isValid;
}

bool RescaleQuantizedNode::verify() const {
// Dest must be quantized.
bool isValid = checkType(getResult(), ElemKind::Int8QTy, this);
// Src must be quantized.
isValid &= checkType(getInput(), ElemKind::Int8QTy, this);
bool isValid =
expectCompareTrue("Dest must be quantized",
getResult().getType()->isQuantizedType(), true, this);
isValid &=
checkType(getResult(), getInput().getType()->getElementType(), this);
isValid &= checkSameShape(getResult(), getInput(), this);
return isValid;
}
Expand Down

0 comments on commit 2823858

Please sign in to comment.