Skip to content

Commit

Permalink
Add clipping only inputs option and clip + convert Storage (pytorch#3855
Browse files Browse the repository at this point in the history
)

Summary:
WIP. Add option to convert only outputs of nodes. Also add clipping and converting of Storage nodes since we'd be no longer clipping inputs.
Pull Request resolved: pytorch#3855

Test Plan:
Added a couple tests. Other tests are failing and need to be fixed before landing.

CC: mjanderson09 nrsatish

Reviewed By: nrsatish

Differential Revision: D18919099

Pulled By: jfix71

fbshipit-source-id: 5e325afac88929c7edaa8e7d67c3e1d7b7212462
  • Loading branch information
jfix71 authored and facebook-github-bot committed Dec 18, 2019
1 parent cb46537 commit aab4944
Show file tree
Hide file tree
Showing 15 changed files with 254 additions and 86 deletions.
12 changes: 0 additions & 12 deletions include/glow/Backend/BackendUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,18 +171,6 @@ class RuntimeBundle {
/// has_getFusedActivation that looks for said method.
CLASS_CONTAINS_METHOD(getFusedActivation)

/// If \p PH is an output placeholder in the Function \p F,
/// \returns true.
/// This is determined by checking if the PH has a user which uses the PH as an
/// overwritten input.
bool isOutput(const Placeholder *PH, const Function &F);

/// If \p PH is an input placeholderin the Function \p F,
/// \returns true.
/// This is determined by checking if the PH is the input to a saveNode or is
/// used by a non saveNode.
bool isInput(const Placeholder *PH, const Function &F);

/// If \p PH is an output placeholder in the IRFunction \p F,
/// \returns true.
/// This is determined by checking if the PH has weights which are referenced by
Expand Down
5 changes: 3 additions & 2 deletions include/glow/Converter/FunctionConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,10 @@ class FunctionConverter {

/// Create a conversion with \p val as input and \p destTy as the destination
/// type in \p function, given \p node. In other words, creates something like
/// cast val to destTy.
/// cast val to destTy. \p isInput represents if this is converting an input.
virtual Node *createConversion(Function &function, const Node &node,
NodeValue &val, TypeRef destTy) = 0;
NodeValue &val, TypeRef destTy,
bool isInput) = 0;

/// Given a \p conversion, get its output value.
/// The default implementation returns the zero-th result.
Expand Down
5 changes: 4 additions & 1 deletion include/glow/Converter/TypeAToTypeBFunctionConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class TypeAToTypeBFunctionConverter : public FunctionConverter {
/// Create a node in \p function that converts \p val to \p destTy, given
/// context \p node. \p val and \p destTy must have the same shape.
Node *createConversion(Function &function, const Node &node, NodeValue &val,
TypeRef destTy) override;
TypeRef destTy, bool isInput) override;

/// Check if \p node can be converted.
bool canConvert(const Node &node) const override;
Expand All @@ -69,6 +69,9 @@ class TypeAToTypeBFunctionConverter : public FunctionConverter {
/// \p precConfig.
TypeAToTypeBFunctionConverter(Function &F, ElemKind fromKind, ElemKind toKind,
const PrecisionConfiguration &precConfig);

/// Convert and clip all Storage nodes used by the function.
void convertAndClipStorage();
};
} // namespace glow
#endif
18 changes: 17 additions & 1 deletion include/glow/Graph/Graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -1229,7 +1229,11 @@ class Function final : public Named {
/// based on the \p input type.
ClipNode *createClip(llvm::StringRef name, NodeValue input, float min,
float max);
/// @}

/// Creates and \returns a ClipNode to the min/max range of FP16 with \p name
/// of \p input. Result type will be implicitly set based on the \p input
/// type.
ClipNode *createClipMinMaxFP16(llvm::StringRef name, NodeValue input);

/// @name The builder functions below are identical to the builder functions
/// above except that they create nodes that use Placeholder instead of
Expand Down Expand Up @@ -1535,6 +1539,18 @@ SaveNode *getOutputSave(Function *F, Placeholder *PH);
/// currToNew.
Node *recursiveClone(Function *newF, Node *node, NodeMap &currToNew);

/// If \p PH is an output placeholder in the Function \p F,
/// \returns true.
/// This is determined by checking if the PH has a user which uses the PH as an
/// overwritten input.
bool isOutput(const Placeholder *PH, const Function &F);

/// If \p PH is an input placeholderin the Function \p F,
/// \returns true.
/// This is determined by checking if the PH is the input to a saveNode or is
/// used by a non saveNode.
bool isInput(const Placeholder *PH, const Function &F);

/// Helper vectors for common transpose shuffles.
#define NCHW2NHWC \
{ 0u, 2u, 3u, 1u }
Expand Down
7 changes: 4 additions & 3 deletions include/glow/Graph/NodeValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,13 @@ struct NodeValue {

/// Replace all of the uses in \p F of this value with \p v. Types of the node
/// value and \p v should be exactly the same.
void replaceAllUsesOfWith(NodeValue v, const Function *F = nullptr) const;
void replaceAllUsesOfWith(NodeValue v, const Function *F = nullptr,
Node *skipReplacement = nullptr) const;

/// Replace all of the uses in \p F of this value with \p v. Types of the node
/// value and \p v can be different.
void typeUnsafeReplaceAllUsesOfWith(NodeValue v,
const Function *F = nullptr) const;
void typeUnsafeReplaceAllUsesOfWith(NodeValue v, const Function *F = nullptr,
Node *skipReplacement = nullptr) const;

/// Return the TypeRef of the referenced return value.
TypeRef getType() const;
Expand Down
12 changes: 11 additions & 1 deletion include/glow/Optimizer/GraphOptimizer/CompilationContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,19 @@ struct PrecisionConfiguration {
/// Whether to convert UInt8FusedQTy to UInt8FusedFP16QTy in the Function.
bool convertFusedToFP16{false};

/// Whether to clip out-of-range FP values to the min/max of fp16.
/// If convertToFP16, whether to convert input Placeholders.
bool convertPlaceholdersToFP16{false};

/// If convertToFP16, whether to convert Constants.
bool convertConstantsToFP16{false};

/// If convertToFP16, whether to clip out-of-range FP values to the min/max of
/// fp16.
bool clipFP16{false};

/// If clipFP16, whether to skip clipping inputs of Nodes.
bool clipFP16SkipInputs{false};

/// Used during Quantization and convertToFP16 to keep the original precision
/// of specific node kinds (i.e. quantization/FP16 conversion would be skipped
/// for any node kinds found here). Used during profiling to prevent nodes
Expand Down
47 changes: 0 additions & 47 deletions lib/Backend/BackendUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,53 +237,6 @@ runtime::RuntimeBundle::getSymbolInfo(const Named *v) const {

namespace glow {

/// If \p PH is an output placeholder, \returns true.
/// This is determined by checking if the PH has a user which uses the PH as an
/// overwritten input.
bool isOutput(const Placeholder *PH, const Function &F) {
for (const auto &use : PH->getUsers()) {
// Look through the inputs of the PH's users. If an input is overwritten
// check if it's the PH, if it is return true.
auto *user = use.getUser();
// Consider only users inside the same function.
if (user->getParent() != &F) {
continue;
}
for (unsigned i = 0, numInputs = user->getNumInputs(); i < numInputs; i++) {
// If the input is not overwritten we can continue.
if (!user->isOverwrittenNthInput(i)) {
continue;
}
auto input = use.getUser()->getNthInput(i);
if (input.getNode() == PH) {
return true;
}
}
}
return false;
}

/// If \p PH is an input placeholder, \returns true.
bool isInput(const Placeholder *PH, const Function &F) {
// Check that the PH is the input to a saveNode or is used by a non saveNode.
for (const auto &use : PH->getUsers()) {
// Consider only users inside the same function.
if (use.getUser()->getParent() != &F) {
continue;
}
// Check if PH is an input to a saveNode.
if (auto *save = dyn_cast<SaveNode>(use.getUser())) {
auto input = save->getInput();
// If the PH is not an input to the saveNode we keep looking.
if (input.getNode() != PH) {
continue;
}
}
return true;
}
return false;
}

/// If \p PH is an output placeholder in the function \p F, \returns true.
/// This is determined by checking if the PH has a user which uses the PH as an
/// overwritten input.
Expand Down
5 changes: 5 additions & 0 deletions lib/Converter/Float16Converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ void glow::convertFunctionToFloat16(Function *F,
ElemKind::Float16Ty, precConfig);
if (precConfig.convertToFP16) {
converter.convert();

// Storage nodes are not converted + clipped directly -- they need to be
// converted via adding ConvertToNodes instead of directly setting their
// types like the TypeAToTypeBFunctionConverter does.
converter.convertAndClipStorage();
}

// Now we want to additionally convert all nodes with inputs in UInt8FusedQTy
Expand Down
9 changes: 6 additions & 3 deletions lib/Converter/FunctionConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ void FunctionConverter::convertOutputs(Node &node) {
// a save node, we actually have to convert the input.
if (saveNode && saveNode->getOutput() == val) {
NodeValue input = saveNode->getInput();
Node *conversion = createConversion(*parent, node, input, targetTy);
Node *conversion = createConversion(*parent, node, input, targetTy,
/* isInput */ false);
saveNode->setNthInput(SaveNode::InputIdx,
getConversionOutput(*conversion));
continue;
Expand All @@ -112,7 +113,8 @@ void FunctionConverter::convertOutputs(Node &node) {
auto conversionValIt = functionAndValToConversion.find(functionAndVal);
if (conversionValIt == functionAndValToConversion.end()) {
// Create the conversion.
Node *conversion = createConversion(*parent, node, val, origTy);
Node *conversion =
createConversion(*parent, node, val, origTy, /* isInput */ false);
// "conversion" uses val so after this call,
// we will get a use of conversion inside conversion.
NodeValue conversionVal = getConversionOutput(*conversion);
Expand Down Expand Up @@ -152,7 +154,8 @@ void FunctionConverter::convertInputs(Node &node) {
assert(targetTy->dims() == val.getType()->dims() &&
"Conversion does not preserve shape");
// Create the conversion.
Node *conversion = createConversion(function_, node, val, targetTy);
Node *conversion =
createConversion(function_, node, val, targetTy, /* isInput */ true);
node.setNthInput(idx, getConversionOutput(*conversion));
}
}
Expand Down
58 changes: 49 additions & 9 deletions lib/Converter/TypeAToTypeBFunctionConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,16 @@ TypeAToTypeBFunctionConverter::getTargetTypeForInput(const Node &use,
Node *TypeAToTypeBFunctionConverter::createConversion(Function &function,
const Node &node,
NodeValue &val,
TypeRef destTy) {
TypeRef destTy,
bool isInput) {
assert(((destTy->getElementType() == dstKind_ &&
val.getType()->getElementType() == srcKind_) ||
(destTy->getElementType() == srcKind_ &&
val.getType()->getElementType() == dstKind_)) &&
"Unexpected conversion type");

bool needClip = precConfig_.clipFP16;
bool needClip = dstKind_ == ElemKind::Float16Ty && precConfig_.clipFP16 &&
!(isInput && precConfig_.clipFP16SkipInputs);
if (needClip) {
switch (node.getKind()) {
case Kinded::Kind::ConcatNodeKind:
Expand All @@ -79,20 +81,15 @@ Node *TypeAToTypeBFunctionConverter::createConversion(Function &function,
assert((destTy->getElementType() == ElemKind::Float16Ty ||
val.getType()->getElementType() == ElemKind::Float16Ty) &&
"Unexpected conversion type");
constexpr float float16Max = 65504.0f;
constexpr float float16Min = -65504.0f;

// If the input is fp32 and output is fp16, then we want to do the convert
// before the clip. This way the clip can execute in fp16 mode.
if (destTy->getElementType() == ElemKind::Float16Ty &&
val.getType()->getElementType() == ElemKind::FloatTy) {
auto convert =
function.createConvertTo(val.getNode()->getName(), val, destTy);
return function.createClip(val.getNode()->getName(), convert, float16Min,
float16Max);
return function.createClipMinMaxFP16(val.getNode()->getName(), convert);
} else {
auto clip = function.createClip(val.getNode()->getName(), val, float16Min,
float16Max);
auto clip = function.createClipMinMaxFP16(val.getNode()->getName(), val);
return function.createConvertTo(val.getNode()->getName(), clip, destTy);
}
} else {
Expand All @@ -105,3 +102,46 @@ void TypeAToTypeBFunctionConverter::convertTensor(Tensor &tensor,
assert(destTy->getElementType() == dstKind_);
tensor.convertToType(dstKind_);
}

void convertAndClipStorageHelper(Storage &S, Function &F, bool clipFP16,
ElemKind srcKind, ElemKind dstKind) {
if (S.getOutput().getType()->getElementType() != srcKind) {
return;
}

ConvertToNode *convertToFP16 =
F.createConvertTo("convert to", S.getOutput(), dstKind);

NodeValue NV = convertToFP16->getResult();
if (clipFP16) {
NV = F.createClipMinMaxFP16(S.getName(), NV)->getResult();
}

// We have to convert back to the srcKind now as the users currently must be
// expecting FP32. The optimizer will remove if possible.
NodeValue convertBack =
F.createConvertTo("convert back", NV, srcKind)->getResult();

// We need to specify to skip replacing convertToFP16 here as otherwise we
// will create a cycle in the graph.
S.getOutput().replaceAllUsesOfWith(convertBack, &F, convertToFP16);
}

void TypeAToTypeBFunctionConverter::convertAndClipStorage() {
if (precConfig_.convertPlaceholdersToFP16) {
for (Placeholder *PH : function_.findPlaceholders()) {
// If the PH is not used as an input then we do not clip it.
if (!isInput(PH, function_)) {
continue;
}
convertAndClipStorageHelper(*PH, function_, precConfig_.clipFP16,
srcKind_, dstKind_);
}
}
if (precConfig_.convertConstantsToFP16) {
for (Constant *C : function_.findConstants()) {
convertAndClipStorageHelper(*C, function_, precConfig_.clipFP16, srcKind_,
dstKind_);
}
}
}
54 changes: 54 additions & 0 deletions lib/Graph/Graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2276,6 +2276,13 @@ ClipNode *Function::createClip(llvm::StringRef name, NodeValue input, float min,
return addNode(new ClipNode(name, input.getType(), input, min, max));
}

ClipNode *Function::createClipMinMaxFP16(llvm::StringRef name,
NodeValue input) {
constexpr float float16Min = -65504.0f;
constexpr float float16Max = 65504.0f;
return createClip(name, input, float16Min, float16Max);
}

//===----------------------------------------------------------------------===//
// Placeholder-builder methods.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -4267,6 +4274,53 @@ Node *glow::recursiveClone(Function *newF, Node *node, NodeMap &currToNew) {
}

namespace glow {
/// If \p PH is an output placeholder, \returns true.
/// This is determined by checking if the PH has a user which uses the PH as an
/// overwritten input.
bool isOutput(const Placeholder *PH, const Function &F) {
for (const auto &use : PH->getUsers()) {
// Look through the inputs of the PH's users. If an input is overwritten
// check if it's the PH, if it is return true.
auto *user = use.getUser();
// Consider only users inside the same function.
if (user->getParent() != &F) {
continue;
}
for (unsigned i = 0, numInputs = user->getNumInputs(); i < numInputs; i++) {
// If the input is not overwritten we can continue.
if (!user->isOverwrittenNthInput(i)) {
continue;
}
auto input = use.getUser()->getNthInput(i);
if (input.getNode() == PH) {
return true;
}
}
}
return false;
}

/// If \p PH is an input placeholder, \returns true.
bool isInput(const Placeholder *PH, const Function &F) {
// Check that the PH is the input to a saveNode or is used by a non saveNode.
for (const auto &use : PH->getUsers()) {
// Consider only users inside the same function.
if (use.getUser()->getParent() != &F) {
continue;
}
// Check if PH is an input to a saveNode.
if (auto *save = dyn_cast<SaveNode>(use.getUser())) {
auto input = save->getInput();
// If the PH is not an input to the saveNode we keep looking.
if (input.getNode() != PH) {
continue;
}
}
return true;
}
return false;
}

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Module &mod) {
mod.dump(os);
return os;
Expand Down
Loading

0 comments on commit aab4944

Please sign in to comment.