Manual revert of D27369251 (pytorch#56080)
Summary:
Fixes #{issue number}

Pull Request resolved: pytorch#56080

Reviewed By: hansonw

Differential Revision: D27777498

Pulled By: Krovatkin

fbshipit-source-id: f72ca725ceba3c1fbd54c30014ac001d4b35b9eb
Krovatkin authored and facebook-github-bot committed Apr 15, 2021
1 parent f8d331b commit 92a09fb
Showing 2 changed files with 0 additions and 122 deletions.
61 changes: 0 additions & 61 deletions test/jit/test_autodiff_subgraph_slicing.py
@@ -9,7 +9,6 @@
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, disable_autodiff_subgraph_inlining
from torch.testing import FileCheck
from torch.testing._internal.common_utils import num_profiled_runs

if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
@@ -49,66 +48,6 @@ def func(x):
        output = func(input, profile_and_replay=True)
        self.assertAutodiffNode(func.graph_for(input), True, ['prim::ConstantChunk'], [])

    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_differentiable_graph_ops_requires_grad(self):
        x = torch.randn(8, 2, dtype=torch.float).requires_grad_()
        y = torch.randn(8, 2, dtype=torch.float)

        def t(x : torch.Tensor, y : torch.Tensor):
            o = x + 1.0
            o1 = torch.relu(o)
            o = y + 1.5
            o2 = torch.relu(o)
            o3 = o1 + o2

            _ = o1.add_(1.0)
            _ = o2.add_(1.0)
            o = o1 * 1.0
            oo1 = torch.relu(o)
            o = o2 * 2.0
            oo2 = torch.relu(o)
            oo3 = oo1 + oo2
            return o1, o2, o3, oo1, oo2, oo3

        with enable_profiling_mode_for_profiling_tests():

            t_jit = torch.jit.script(t)
            jit_o = t_jit(x, y)
            jit_o = t_jit(x, y)
            o = t(x, y)

            FileCheck().check("prim::DifferentiableGraph").run(t_jit.graph_for(x, y))
            # validate that the differentiable graph ops mark requires_grad properly
            for oo, jit_oo in zip(o, jit_o):
                self.assertEqual(oo.requires_grad, jit_oo.requires_grad)
                self.assertEqual(oo, jit_oo)
            # one more run to trigger fusion
            jit_o = t_jit(x, y)
            for oo, jit_oo in zip(o, jit_o):
                self.assertEqual(oo.dtype, jit_oo.dtype)
                self.assertEqual(oo.requires_grad, jit_oo.requires_grad)
                self.assertEqual(oo, jit_oo)

    @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.PROFILING, "Simple Executor doesn't support gradients")
    def test_prune_grad(self):
        @torch.jit.script
        def t(input, bias):
            return torch.nn.functional.relu(input + bias)
        input = torch.randn(2, 8, requires_grad=True)
        bias = torch.randn(8, requires_grad=False)  # bias does NOT require grad
        NUM_PROFILED_RUNS = 1
        with num_profiled_runs(NUM_PROFILED_RUNS):
            WARMUP = 3  # 2 runs to reach backward + 1 to optimize it
            for x in range(WARMUP):
                o = t(input, bias)
                o.sum().backward()

            fwd_plan = list(t.get_debug_state().execution_plans.values())[0]
            bwd_graph = list(fwd_plan.code.grad_executor_states()[0].execution_plans.values())[0].graph
            tup = next(bwd_graph.outputs())
            self.assertEqual(len(list(tup.node().inputs())), 1)

    def test_simple_merge(self):
        # o --> o
        def fn(x, y, z):
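
The removed test_differentiable_graph_ops_requires_grad checks one invariant: once the profiling executor has specialized a scripted function into a prim::DifferentiableGraph, its outputs should carry the same requires_grad flags and values as eager execution. A minimal sketch of that check, using only public torch.jit APIs (illustrative only, not part of this commit's diff):

import torch

def f(x, y):
    a = torch.relu(x + 1.0)   # depends on x, so requires_grad should be True
    b = torch.relu(y + 1.5)   # depends only on y, so requires_grad should be False
    return a, b, a + b

x = torch.randn(8, 2, requires_grad=True)
y = torch.randn(8, 2)

f_jit = torch.jit.script(f)
for _ in range(3):            # warm-up runs so the profiling executor can specialize
    jit_out = f_jit(x, y)

for eager_o, jit_o in zip(f(x, y), jit_out):
    assert eager_o.requires_grad == jit_o.requires_grad
    assert torch.allclose(eager_o, jit_o)
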
61 changes: 0 additions & 61 deletions torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp
@@ -129,66 +129,6 @@ static bool needsGradientInProfilingMode(Block* b) {
  return false;
}

// `prim::RequiresGradCheck` guarantees that the requires_grad properties
// of input tensors will match the profiled ones; otherwise a fallback path
// will be triggered. This allows us to prune off gradients in the backward
// graph for inputs that don't need gradients. We transfer requires_grad
// properties from the inputs of the `prim::DifferentiableGraph` node onto the
// inputs of its subgraph. Autodiff will inspect these properties and prune
// off gradients that aren't required.
// `requires_grad` properties from `dnode->outputs()` will also be transferred.
static void setRequiresGradOnDiffGraph(Node* dnode) {
  auto gi = dnode->g(attr::Subgraph)->inputs();
  for (size_t i = 0; i < dnode->inputs().size(); i++) {
    if (auto ty = dnode->input(i)->type()->cast<TensorType>()) {
      auto gi_ty = gi[i]->type()->expect<TensorType>();
      gi[i]->setType(gi_ty->withRequiresGrad(ty->requires_grad()));
      GRAPH_DEBUG(
          "Setting ",
          *gi_ty->withRequiresGrad(ty->requires_grad()),
          " on ",
          gi[i],
          " ",
          gi[i]->debugName());
    }
  }

  // We also need to put requires_grad on outputs within the subgraph, so
  // autodiff can set df_input_vjps and DifferentiableGraphOp can set
  // `requires_grad=` properly.
  auto go = dnode->g(attr::Subgraph)->outputs();
  for (size_t i = 0; i < go.size(); i++) {
    auto ty = go[i]->type()->cast<TensorType>();
    if (ty) {
      auto n = go[i]->node();
      auto dno = dnode->outputs().at(i);
      auto dno_use0 = dno->uses().at(0);
      GRAPH_DEBUG("found first user of ", i, " as ", *dno_use0.user);
      if (n->kind() == prim::profile) {
        GRAPH_DEBUG(
            "setting output ", i, " to type ", *n->ty(attr::profiled_type));
        go[i]->setType(n->ty(attr::profiled_type));
      } else if (dno_use0.user->kind() == prim::profile) {
        GRAPH_DEBUG(
            "setting output ",
            i,
            " to type ",
            *dno_use0.user->ty(attr::profiled_type));
        go[i]->setType(dno_use0.user->ty(attr::profiled_type));
      } else if (dno_use0.user->kind() == prim::DifferentiableGraph) {
        Value* o =
            dno_use0.user->g(attr::Subgraph)->inputs().at(dno_use0.offset);
        auto nn = o->uses().at(0).user;
        if (nn->kind() == prim::profile) {
          GRAPH_DEBUG(
              "setting output ", i, " to type ", *nn->ty(attr::profiled_type));
          go[i]->setType(nn->ty(attr::profiled_type));
        }
      }
    }
  }
}

bool guardDifferentiableGraph(Node* dnode) {
  auto gi = dnode->g(attr::Subgraph)->inputs();
  bool all_inputs_seen = true;
@@ -222,7 +162,6 @@ bool guardDifferentiableGraph(Node* dnode) {
    }
  }
  if (all_inputs_seen) {
    setRequiresGradOnDiffGraph(dnode);
    // we may have seen both true and false for requires_grad. In this case
    // we guard with true here and the other case is in the fallback. This
    // will give us trouble when we get "alternating patterns" of gradients
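
The reverted setRequiresGradOnDiffGraph pass transferred requires_grad information onto the differentiable subgraph so that autodiff could prune gradient computation for inputs that don't need it, which is the property test_prune_grad inspected via the backward graph. From the Python side, the observable expectation is simply that a non-differentiable input never receives a .grad; a small sketch of that expectation (illustrative only, not part of this commit's diff):

import torch

@torch.jit.script
def t(inp, bias):
    return torch.nn.functional.relu(inp + bias)

inp = torch.randn(2, 8, requires_grad=True)
bias = torch.randn(8, requires_grad=False)   # bias does NOT require grad

for _ in range(3):                           # warm-up runs for the profiling executor
    t(inp, bias).sum().backward()

assert inp.grad is not None                  # gradient flows to inp
assert bias.grad is None                     # no gradient is ever produced for bias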
