Skip to content

Commit

Permalink
Add ReplaceNaN Node so backends can prevent lowering if desired
Browse files Browse the repository at this point in the history
  • Loading branch information
jfix71 committed Apr 26, 2019
1 parent d1f4f07 commit 4e64877
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 39 deletions.
9 changes: 3 additions & 6 deletions include/glow/Graph/Graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -568,12 +568,9 @@ class Function final : public Named {
/// element in \p input is NaN or not.
IsNaNNode *createIsNaN(llvm::StringRef name, NodeValue input);

/// Implements an operation that replaces all instances of NaN in \p input
/// with \p value. This operation is lowered to a Select node with \p input
/// as one of the inputs, a Splat node created using \p value as the other
/// input, and an IsNaN node as the comparator input.
/// \returns the Select node.
Node *createReplaceNaN(llvm::StringRef name, NodeValue input, float value);
/// \returns a ReplaceNaNNode given \p name, \p input, and \p value.
ReplaceNaNNode *createReplaceNaN(llvm::StringRef name, NodeValue input,
float value);

PowNode *createPow(llvm::StringRef name, NodeValue base, float exp);

Expand Down
15 changes: 3 additions & 12 deletions lib/Graph/Graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1172,18 +1172,9 @@ IsNaNNode *Function::createIsNaN(llvm::StringRef name, NodeValue input) {
return addNode(new IsNaNNode(name, OT, input));
}

Node *Function::createReplaceNaN(llvm::StringRef name, NodeValue input,
float value) {
// Create IsNaN node.
auto *INN = createIsNaN(name.str() + ".isNaN", input);

// Create Splat node.
auto *S = createSplat(name.str() + ".splat", input.getType(), value);

// Create Select node to pick between original and replacement values.
auto *SN = createSelect(name.str() + ".select", INN, S, input);

return SN;
ReplaceNaNNode *Function::createReplaceNaN(llvm::StringRef name,
NodeValue input, float value) {
return addNode(new ReplaceNaNNode(name, input.getType(), input, value));
}

PowNode *Function::createPow(llvm::StringRef name, NodeValue base, float exp) {
Expand Down
4 changes: 4 additions & 0 deletions lib/Graph/Nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1189,6 +1189,10 @@ bool IsNaNNode::verify() const {
return isValid;
}

bool ReplaceNaNNode::verify() const {
return checkSameType(getResult(), getInput(), this);
}

bool SelectNode::verify() const {
bool isValid = checkSameShape(getResult(), getLHS(), this);
isValid &= checkSameShape(getResult(), getRHS(), this);
Expand Down
21 changes: 21 additions & 0 deletions lib/Optimizer/Lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,25 @@ static void lowerBatchReduceMeanNode(Function *F, LoweredInfoMap *loweredMap,
replaceAllUsesOfWith(loweredMap, BRM.getResult(), DN);
}

/// Implement ReplaceNaN via a Select node with the input of \p RN as one of the
/// inputs, a Splat node created using value from \p RN as the other input, and
/// an IsNaN node as the comparator input.
static void lowerReplaceNaNNode(Function *F, LoweredInfoMap *loweredMap,
const ReplaceNaNNode &RN) {
// Create IsNaN node.
auto *INN = F->createIsNaN(RN.getName().str() + ".isNaN", RN.getInput());

// Create Splat node.
auto *S = F->createSplat(RN.getName().str() + ".splat",
RN.getInput().getType(), RN.getValue());

// Create Select node to pick between original and replacement values.
auto *SN =
F->createSelect(RN.getName().str() + ".select", INN, S, RN.getInput());

replaceAllUsesOfWith(loweredMap, RN.getResult(), SN);
}

/// Lowers \p node given Function \p. If \p loweredMap is not a nullptr, it will
/// log the lowering info of what was replaced by what via output names.
static void lowerNode(Function *F, Node *node, LoweredInfoMap *loweredMap) {
Expand Down Expand Up @@ -764,6 +783,8 @@ static void lowerNode(Function *F, Node *node, LoweredInfoMap *loweredMap) {
lowerTileNode(F, loweredMap, *TN);
} else if (auto *CSN = dyn_cast<ChannelShuffleNode>(node)) {
lowerChannelShuffleNode(F, loweredMap, *CSN);
} else if (auto *RN = dyn_cast<ReplaceNaNNode>(node)) {
lowerReplaceNaNNode(F, loweredMap, *RN);
}
}

Expand Down
16 changes: 7 additions & 9 deletions tests/unittests/Caffe2ImporterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -995,16 +995,14 @@ TEST(caffe2, replaceNaN) {
EXPECT_TRUE(output->dims().vec() == expectedDims);

// High level checks on the content of the graph.
// We have 1 IsNaN, 1 Splat, 1 Select and 1 Output.
EXPECT_EQ(F->getNodes().size(), 4);
// We have 1 ReplaceNaN and 1 Output.
EXPECT_EQ(F->getNodes().size(), 2);
auto *saveNode = getSaveNodeFromDest(output);
auto *selectNode = llvm::dyn_cast<SelectNode>(saveNode->getInput().getNode());
ASSERT_TRUE(selectNode);
auto *isNaNNode = llvm::dyn_cast<IsNaNNode>(selectNode->getCond().getNode());
ASSERT_TRUE(isNaNNode);
auto *splatNode = llvm::dyn_cast<SplatNode>(selectNode->getLHS().getNode());
ASSERT_TRUE(splatNode);
auto *inputNode = llvm::dyn_cast<Placeholder>(selectNode->getRHS().getNode());
auto *replaceNaNNode =
llvm::dyn_cast<ReplaceNaNNode>(saveNode->getInput().getNode());
EXPECT_EQ(replaceNaNNode->getValue(), 1.0f);
auto *inputNode =
llvm::dyn_cast<Placeholder>(replaceNaNNode->getInput().getNode());
ASSERT_EQ(inputNode, mod.getPlaceholderByName("input"));

// We have one input and one output.
Expand Down
24 changes: 12 additions & 12 deletions tests/unittests/OnnxImporterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1006,18 +1006,18 @@ TEST(onnx, importReplaceNaN) {
updateInputPlaceholdersByName(bindings, &mod, {"x"}, {&x});
}

// Verify structure: Input, IsNan, Splat -> Select -> Save.
ASSERT_EQ(mod.getPlaceholders().size(), 2);
ASSERT_EQ(F->getNodes().size(), 4);
auto *save = getSaveNodeFromDest(output);
auto *select = llvm::dyn_cast<SelectNode>(save->getInput().getNode());
ASSERT_TRUE(select);
auto *isNaN = llvm::dyn_cast<IsNaNNode>(select->getCond().getNode());
ASSERT_TRUE(isNaN);
auto *splat = llvm::dyn_cast<SplatNode>(select->getLHS().getNode());
ASSERT_TRUE(splat);
auto *input = llvm::dyn_cast<Placeholder>(select->getRHS().getNode());
ASSERT_EQ(input, mod.getPlaceholderByName("x"));
// Verify structure: Input -> ReplaceNaN -> Save.
EXPECT_EQ(F->getNodes().size(), 2);
auto *saveNode = getSaveNodeFromDest(output);
auto *replaceNaNNode =
llvm::dyn_cast<ReplaceNaNNode>(saveNode->getInput().getNode());
EXPECT_EQ(replaceNaNNode->getValue(), 1.0f);
auto *inputNode =
llvm::dyn_cast<Placeholder>(replaceNaNNode->getInput().getNode());
ASSERT_EQ(inputNode, mod.getPlaceholderByName("x"));

// We have one input and one output.
EXPECT_EQ(mod.getPlaceholders().size(), 2);
}

/// Test loading SparseToDense op from an ONNX model.
Expand Down
6 changes: 6 additions & 0 deletions tools/ClassGen/NodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,12 @@ int main(int argc, char **argv) {
"generates a mask that can be consumed by a Select node.");
// clang-format on

BB.newNode("ReplaceNaN")
.addInput("Input")
.addMember(MemberType::Float, "Value")
.addResultFromCtorArg()
.setDocstring("Replaces NaNs found in Input with Value.");

BB.newNode("Modulo")
.addInput("Input")
.addMember(MemberType::Int64, "Divisor")
Expand Down

0 comments on commit 4e64877

Please sign in to comment.