diff --git a/torch/csrc/jit/jit_log.cpp b/torch/csrc/jit/jit_log.cpp
index 4ec5c230c1f5b1..0c22f5e9bd3d5b 100644
--- a/torch/csrc/jit/jit_log.cpp
+++ b/torch/csrc/jit/jit_log.cpp
@@ -1,5 +1,6 @@
 #include
 #include
+#include <torch/csrc/jit/ir.h>
 #include
 #include
@@ -14,6 +15,10 @@ JitLoggingLevels jit_log_level() {
   return log_level;
 }
 
+std::string debugValueOrDefault(const Node* n) {
+  return n->outputs().size() > 0 ? n->outputs().at(0)->debugName() : "n/a";
+}
+
 std::string jit_log_prefix(JitLoggingLevels level, const std::string& in_str) {
   std::stringstream in_ss(in_str);
   std::stringstream out_ss(in_str);
diff --git a/torch/csrc/jit/jit_log.h b/torch/csrc/jit/jit_log.h
index c135df6911a36b..8aafc413181e65 100644
--- a/torch/csrc/jit/jit_log.h
+++ b/torch/csrc/jit/jit_log.h
@@ -19,6 +19,8 @@ enum class JitLoggingLevels {
   GRAPH_DEBUG,
 };
 
+std::string debugValueOrDefault(const class Node* n);
+
 JitLoggingLevels jit_log_level();
 
 std::string jit_log_prefix(JitLoggingLevels level, const std::string& in_str);
diff --git a/torch/csrc/jit/passes/lower_grad_of.cpp b/torch/csrc/jit/passes/lower_grad_of.cpp
index 944bef07fc0b39..0d67ec9b8bb0ff 100644
--- a/torch/csrc/jit/passes/lower_grad_of.cpp
+++ b/torch/csrc/jit/passes/lower_grad_of.cpp
@@ -1,4 +1,5 @@
 #include
+#include <torch/csrc/jit/jit_log.h>
 
 namespace torch {
 namespace jit {
@@ -26,6 +27,11 @@ void LowerGradOf(Graph& g) {
       else_block->registerOutput(undef);
       if_stat->outputs().at(i)->copyMetadata(it->outputs().at(i));
     }
+    GRAPH_UPDATE(
+        "Replacing node prim::GradOf w/ output ",
+        debugValueOrDefault(*it),
+        " with ",
+        debugValueOrDefault(if_stat));
     it->replaceAllUsesWith(if_stat);
     it.destroyCurrent();
   }
diff --git a/torch/csrc/jit/passes/specialize_autogradzero.cpp b/torch/csrc/jit/passes/specialize_autogradzero.cpp
index 245668c6727e6b..ceb2266a0ee261 100644
--- a/torch/csrc/jit/passes/specialize_autogradzero.cpp
+++ b/torch/csrc/jit/passes/specialize_autogradzero.cpp
@@ -1,4 +1,5 @@
 #include
+#include <torch/csrc/jit/jit_log.h>
 #include
 
 namespace torch {
@@ -40,6 +41,11 @@ void specializeAutogradZero(Graph& g) {
       if (all_zeros) {
         auto zero = g.createAutogradZero()->insertAfter(n)->output();
         for (auto o : n->outputs()) {
+          GRAPH_UPDATE(
+              "Replacing output ",
+              o->debugName(),
+              " with AutogradZero ",
+              zero->debugName());
           o->replaceAllUsesWith(zero);
         }
       } else {
@@ -58,14 +64,22 @@ void specializeAutogradZero(Graph& g) {
         AT_ASSERT(state[input] != State::Unknown);
       }
       // hoist the nodes in the GradOf body to be before the linear block
+      GRAPH_UPDATE("Hoisting out prim::GradOf ", debugValueOrDefault(*it));
       for (auto it = body->nodes().begin(); it != body->nodes().end();) {
         auto block_node = *it++;
         block_node->moveBefore(n);
       }
-      for (size_t i = 0; i < n->outputs().size(); ++i)
+      for (size_t i = 0; i < n->outputs().size(); ++i) {
+        GRAPH_UPDATE(
+            "Replacing prim::GradOf's use ",
+            n->outputs().at(i)->debugName(),
+            " with hoisted value ",
+            body->outputs().at(i)->debugName());
         n->outputs().at(i)->replaceAllUsesWith(body->outputs().at(i));
+      }
     }
+    GRAPH_UPDATE("Destroying node ", debugValueOrDefault(*it));
     it.destroyCurrent();
   } break;
   case prim::AutogradAdd: {
@@ -75,9 +89,19 @@ void specializeAutogradZero(Graph& g) {
       if (state[a] == State::Zero) {
         // Zero + b == b
         n->output()->replaceAllUsesWith(b);
+        GRAPH_UPDATE(
+            "Simplifying prim::AutogradAdd(prim::AutogradZero, X) ",
+            n->output()->debugName(),
+            " to ",
+            b->debugName());
         it.destroyCurrent();
       } else if (state[b] == State::Zero) {
         // a + Zero == a
+        GRAPH_UPDATE(
+            "Simplifying prim::AutogradAdd(X, prim::AutogradZero) ",
+            n->output()->debugName(),
+            " to ",
+            a->debugName());
         n->output()->replaceAllUsesWith(a);
         it.destroyCurrent();
       } else if (state[a] == State::Nonzero && state[b] == State::Nonzero) {
@@ -86,6 +110,11 @@ void specializeAutogradZero(Graph& g) {
         WithInsertPoint guard(n);
         Value* new_add = toVar(a) + toVar(b);
         state[new_add] = State::Nonzero;
+        GRAPH_UPDATE(
+            "Simplifying prim::AutogradAdd ",
+            n->output()->debugName(),
+            " to prim::Add ",
+            new_add->debugName());
         n->output()->replaceAllUsesWith(new_add);
         it.destroyCurrent();
       } else {