Add traces to LowerGradOf and SpecializeAutoGrad
Summary: Pull Request resolved: pytorch#22599

Differential Revision: D16161144

Pulled By: Krovatkin

fbshipit-source-id: 9e206fcfb1796e9448e80f178b75d0c277bd348f
Krovatkin authored and facebook-github-bot committed Jul 9, 2019
1 parent 0c2cd93 commit 50901be
Showing 4 changed files with 43 additions and 1 deletion.
5 changes: 5 additions & 0 deletions torch/csrc/jit/jit_log.cpp
@@ -1,5 +1,6 @@
#include <torch/csrc/jit/jit_log.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir.h>
#include <cstdlib>
#include <sstream>

@@ -14,6 +15,10 @@ JitLoggingLevels jit_log_level() {
return log_level;
}

std::string debugValueOrDefault(const Node* n) {
return !n->outputs().empty() ? n->outputs().at(0)->debugName() : "n/a";
}

std::string jit_log_prefix(JitLoggingLevels level, const std::string& in_str) {
std::stringstream in_ss(in_str);
std::stringstream out_ss(in_str);
2 changes: 2 additions & 0 deletions torch/csrc/jit/jit_log.h
@@ -19,6 +19,8 @@ enum class JitLoggingLevels {
GRAPH_DEBUG,
};

std::string debugValueOrDefault(const struct Node* n);

JitLoggingLevels jit_log_level();

std::string jit_log_prefix(JitLoggingLevels level, const std::string& in_str);
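For context: GRAPH_UPDATE only prints when jit_log_level(), which at the time of this commit is driven by the PYTORCH_JIT_LOG_LEVEL environment variable read in jit_log.cpp, is at least GRAPH_UPDATE. A minimal sketch of using the new helper together with the existing macro, mirroring the call sites in the two passes below (logReplacement is an illustrative name, not part of this change):

    #include <torch/csrc/jit/ir.h>
    #include <torch/csrc/jit/jit_log.h>

    namespace torch {
    namespace jit {

    // Log a rewrite by the debug name of each node's first output;
    // debugValueOrDefault falls back to "n/a" for output-less nodes.
    void logReplacement(const Node* old_node, const Node* new_node) {
      GRAPH_UPDATE(
          "Replacing ", debugValueOrDefault(old_node),
          " with ", debugValueOrDefault(new_node));
    }

    } // namespace jit
    } // namespace torch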
6 changes: 6 additions & 0 deletions torch/csrc/jit/passes/lower_grad_of.cpp
@@ -1,4 +1,5 @@
#include <torch/csrc/jit/passes/lower_grad_of.h>
#include <torch/csrc/jit/jit_log.h>

namespace torch {
namespace jit {
@@ -26,6 +27,11 @@ void LowerGradOf(Graph& g) {
else_block->registerOutput(undef);
if_stat->outputs().at(i)->copyMetadata(it->outputs().at(i));
}
GRAPH_UPDATE(
"Replacing prim::GradOf node ",
debugValueOrDefault(*it),
" with prim::If node ",
debugValueOrDefault(if_stat));
it->replaceAllUsesWith(if_stat);
it.destroyCurrent();
}
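LowerGradOf rewrites each prim::GradOf node into a prim::If that falls back to undefined (AutogradZero) outputs, as the else_block/undef code above suggests, and the new trace reports each such rewrite. A hedged sketch of running the pass with surrounding graph dumps (it assumes a gradient Graph obtained elsewhere, e.g. from autodiff; the wrapper function is illustrative):

    #include <torch/csrc/jit/ir.h>
    #include <torch/csrc/jit/jit_log.h>
    #include <torch/csrc/jit/passes/lower_grad_of.h>

    // Sketch only: with logging at GRAPH_UPDATE or higher, each
    // prim::GradOf -> prim::If rewrite is reported as the pass runs.
    void lowerGradOfWithTraces(torch::jit::Graph& g) {
      GRAPH_DUMP("Before LowerGradOf: ", &g);
      torch::jit::LowerGradOf(g);
      GRAPH_DUMP("After LowerGradOf: ", &g);
    }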
31 changes: 30 additions & 1 deletion torch/csrc/jit/passes/specialize_autogradzero.cpp
@@ -1,4 +1,5 @@
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/symbolic_variable.h>

namespace torch {
@@ -40,6 +41,11 @@ void specializeAutogradZero(Graph& g) {
if (all_zeros) {
auto zero = g.createAutogradZero()->insertAfter(n)->output();
for (auto o : n->outputs()) {
GRAPH_UPDATE(
"Replacing output ",
o->debugName(),
" with AutogradZero ",
zero->debugName());
o->replaceAllUsesWith(zero);
}
} else {
@@ -58,14 +64,22 @@ void specializeAutogradZero(Graph& g) {
AT_ASSERT(state[input] != State::Unknown);
}
// hoist the nodes in the GradOf body to be before the linear block
GRAPH_UPDATE("Hoisting out prim::GradOf ", debugValueOrDefault(*it));
for (auto it = body->nodes().begin(); it != body->nodes().end();) {
auto block_node = *it++;
block_node->moveBefore(n);
}

for (size_t i = 0; i < n->outputs().size(); ++i)
for (size_t i = 0; i < n->outputs().size(); ++i) {
GRAPH_UPDATE(
"Replacing prim::GradOf's use ",
n->outputs().at(i)->debugName(),
" with hoisted value ",
body->outputs().at(i)->debugName());
n->outputs().at(i)->replaceAllUsesWith(body->outputs().at(i));
}
}
GRAPH_UPDATE("Destroying node ", debugValueOrDefault(*it));
it.destroyCurrent();
} break;
case prim::AutogradAdd: {
@@ -75,9 +89,19 @@
if (state[a] == State::Zero) {
// Zero + b == b
GRAPH_UPDATE(
"Simplifying prim::AutogradAdd(prim::AutogradZero, X) ",
n->output()->debugName(),
" to ",
b->debugName());
n->output()->replaceAllUsesWith(b);
it.destroyCurrent();
} else if (state[b] == State::Zero) {
// a + Zero == a
GRAPH_UPDATE(
"Simplifying prim::AutogradAdd(X, prim::AutogradZero) ",
n->output()->debugName(),
" to ",
a->debugName());
n->output()->replaceAllUsesWith(a);
it.destroyCurrent();
} else if (state[a] == State::Nonzero && state[b] == State::Nonzero) {
@@ -86,6 +110,11 @@
WithInsertPoint guard(n);
Value* new_add = toVar(a) + toVar(b);
state[new_add] = State::Nonzero;
GRAPH_UPDATE(
"Simplifying prim::AutogradAdd ",
n->output()->debugName(),
" to aten::add ",
new_add->debugName());
n->output()->replaceAllUsesWith(new_add);
it.destroyCurrent();
} else {
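The same recipe applies to specializeAutogradZero; the comments in this sketch restate the three prim::AutogradAdd cases whose rewrites are now traced (the driver function is illustrative, not part of this change):

    #include <torch/csrc/jit/ir.h>
    #include <torch/csrc/jit/jit_log.h>
    #include <torch/csrc/jit/passes/specialize_autogradzero.h>

    // Sketch: given the statically inferred states of inputs a and b,
    // the pass replaces Zero + b with b, a + Zero with a, and
    // Nonzero + Nonzero with a plain add; each rewrite is now logged
    // by debug name via GRAPH_UPDATE.
    void specializeWithTraces(torch::jit::Graph& g) {
      GRAPH_DUMP("Before specializeAutogradZero: ", &g);
      torch::jit::specializeAutogradZero(g);
      GRAPH_DUMP("After specializeAutogradZero: ", &g);
    }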
