Batch more matrix multiplies (pytorch#13456)
Summary:
This handles the input pre-multiplication in RNNs, yielding pretty significant speedups in backward times. This pass depends on loop unrolling, so we'll batch only as many elements as the unrolling factor allows.
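To make the idea concrete (a minimal sketch of the shared-side pattern, not the pass's implementation): the per-timestep input projections of an RNN all multiply by the same weight, so they can be concatenated along the batch dimension, multiplied once, and split back apart. This is what the new prim::MMBatchSide op represents; shapes and names below are made up for illustration.

import torch

# Three timesteps of the RNN input pre-multiplication: every x_t hits the
# same (transposed) input-to-hidden weight, so the mms share one operand.
w_ih = torch.randn(80, 10)                    # (4 * hidden) x input, LSTM-style
xs = [torch.randn(3, 10) for _ in range(3)]   # batch x input, one per step

# Unbatched: one mm per timestep.
unbatched = [torch.mm(x, w_ih.t()) for x in xs]

# Side-batched: concatenate the inputs, do a single mm, split the result back.
batched = torch.mm(torch.cat(xs, dim=0), w_ih.t()).chunk(len(xs), dim=0)

for u, b in zip(unbatched, batched):
    assert torch.allclose(u, b, atol=1e-6)

Since the rewrite can only group mms that sit in the same straight-line region, BatchMM runs after loop unrolling (see the graph_executor.cpp change below), and the unrolling factor bounds how many multiplies end up in one batch.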

cc mruberry ngimel zou3519 zdevito
Pull Request resolved: pytorch#13456

Differential Revision: D12920339

Pulled By: zou3519

fbshipit-source-id: 5bcd6d259c054a6dea02ae09a9fdf9f030856443
apaszke authored and facebook-github-bot committed Nov 26, 2018
1 parent 1ef9490 commit a603689
Showing 9 changed files with 227 additions and 35 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -67,6 +67,7 @@ namespace c10 {
_(prim, ConstantChunk) \
_(prim, NoneGenerator) \
_(prim, MMTreeReduce) \
_(prim, MMBatchSide) \
_(aten, floordiv) \
_(aten, __round_to_zero_floordiv)\
_(prim, fork) \
25 changes: 12 additions & 13 deletions test/expect/TestScript.test_milstm_fusion_cuda-backward.expect
@@ -27,17 +27,16 @@ graph(%0 : Float(*, *)
%outgate : Float(*, *)
%27 : Float(*, *)) {
%28 : Float(*, *) = prim::FusionGroup_0(%ingate, %forgetgate, %cellgate, %outgate, %12, %1, %27, %0)
%29 : Float(*, *) = aten::mul(%28, %Uz)
%30 : Float(*, *) = aten::mul(%28, %Wx)
%31 : Float(*, *) = prim::FusionGroup_1(%28, %22, %13)
%32 : Float(*, *), %33 : Float(*, *) = prim::FusionGroup_2(%Wx, %15, %28, %Uz, %14)
%29 : Float(*, *) = aten::mul(%28, %Wx)
%30 : Float(*, *) = prim::FusionGroup_1(%28, %22, %13)
%31 : Float(*, *), %32 : Float(*, *), %33 : Float(*, *) = prim::FusionGroup_2(%Wx, %15, %28, %14, %Uz)
%34 : Float(*, *) = aten::t(%16)
%35 : Float(*, *) = aten::mm(%34, %31)
%35 : Float(*, *) = aten::mm(%34, %30)
%36 : Float(*, *) = aten::t(%35)
%37 : Float(*, *) = aten::t(%17)
%38 : Float(*, *) = aten::mm(%37, %33)
%38 : Float(*, *) = aten::mm(%37, %32)
%39 : Float(*, *) = aten::t(%38)
return (%28, %29, %30, %32, %36, %39);
return (%28, %33, %29, %31, %36, %39);
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Float(*, *)
@@ -94,14 +93,14 @@ with prim::FusionGroup_1 = graph(%0 : Float(*, *)
with prim::FusionGroup_2 = graph(%0 : Float(*, *)
%1 : Float(*)
%2 : Float(*, *)
%3 : Float(*, *)
%4 : Float(*)) {
%3 : Float(*)
%4 : Float(*, *)) {
%5 : Float(*, *) = aten::mul(%2, %4)
%6 : Float(*, *) = aten::mul(%2, %3)
%7 : Float(*, *) = aten::mul(%6, %1)
%7 : Float(*, *) = aten::mul(%5, %1)
%8 : int = prim::Constant[value=1]()
%9 : int = prim::Constant[value=1]()
%10 : Float(*, *) = aten::add(%5, %7, %9)
%11 : Float(*, *) = aten::mul(%6, %0)
return (%11, %10);
%10 : Float(*, *) = aten::add(%6, %7, %9)
%11 : Float(*, *) = aten::mul(%5, %0)
return (%11, %10, %5);
}
39 changes: 32 additions & 7 deletions test/test_jit.py
@@ -127,8 +127,9 @@ def canonical(graph):
    return str(torch._C._jit_pass_canonicalize(graph))


def get_lstm_inputs(device, training=False):
    input = torch.randn(3, 10, dtype=torch.float, device=device, requires_grad=training)
def get_lstm_inputs(device, training=False, seq_length=None):
    input_shape = (3, 10) if seq_length is None else (seq_length, 3, 10)
    input = torch.randn(*input_shape, dtype=torch.float, device=device, requires_grad=training)
    hx = torch.randn(3, 20, dtype=torch.float, device=device, requires_grad=training)
    cx = torch.randn(3, 20, dtype=torch.float, device=device, requires_grad=training)
    module = nn.LSTMCell(10, 20).to(device, torch.float)  # Just to allocate weights with correct sizes
@@ -174,19 +175,19 @@ def get_execution_plan(graph_executor_state):
    return execution_plans[0]


def get_grad_executor(plan_state):
    if len(list(plan_state.graph.nodes())) != 1:
def get_grad_executor(plan_state, diff_graph_idx=None):
    if diff_graph_idx is None and len(list(plan_state.graph.nodes())) != 1:
        raise RuntimeError("Can't get a grad_executor for a non-differentiable graph")
    grad_executors = list(plan_state.code.grad_executors())
    return grad_executors[0]
    return grad_executors[diff_graph_idx or 0]


def backward_graph(script_module):
def backward_graph(script_module, diff_graph_idx=None):
    if not isinstance(script_module, torch.jit.ScriptModule):
        raise RuntimeError('Expected ScriptModule')
    ge_state = script_module.get_debug_state()
    fwd_plan = get_execution_plan(ge_state)
    grad_executor = get_grad_executor(fwd_plan)
    grad_executor = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx)
    bwd_plan = get_execution_plan(grad_executor.get_debug_state())
    # Running JIT passes requires that we own the graph (with a shared_ptr).
    # The debug state struct does not own its graph so we make a copy of it.
@@ -6643,6 +6644,30 @@ def func(a):
        self.run_pass('erase_number_types', graph)
        self.assertExpectedGraph(graph)

    def test_mm_batching(self):
        lstm_cell = torch.jit.script(LSTMCellS)

        def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
            for i in range(x.size(0)):
                hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh)
            return hx

        slstm = torch.jit.script(lstm)

        inputs = get_lstm_inputs('cpu', training=True, seq_length=10)
        slstm(*inputs).sum().backward()

        fw_graph = slstm.graph_for(*inputs)
        bw_graph = backward_graph(slstm, diff_graph_idx=0)
        self.assertTrue('prim::MMBatchSide' in str(fw_graph))
        self.assertTrue('prim::MMTreeReduce' in str(bw_graph))

        sout = slstm(*inputs)
        out = lstm(*inputs)
        self.assertEqual(slstm(*inputs), lstm(*inputs))
        self.assertEqual(torch.autograd.grad(slstm(*inputs).sum(), inputs),
                         torch.autograd.grad(lstm(*inputs).sum(), inputs))

    def test_loop_unrolling(self):
        def fn(x):
            y = 0
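For context on the backward-graph assertion in test_mm_batching: prim::MMTreeReduce batches a sum of matrix multiplies, which is the shape gradient accumulation takes across the unrolled timesteps. The algebraic identity it exploits, sketched with made-up shapes (this is the algebra only, not the pass itself):

import torch

# sum_i mm(A_i, B_i)  ==  mm(cat(A_i, dim=1), cat(B_i, dim=0))
As = [torch.randn(20, 3) for _ in range(4)]
Bs = [torch.randn(3, 10) for _ in range(4)]

summed = sum(torch.mm(a, b) for a, b in zip(As, Bs))
reduced = torch.mm(torch.cat(As, dim=1), torch.cat(Bs, dim=0))

assert torch.allclose(summed, reduced, atol=1e-5)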
13 changes: 11 additions & 2 deletions torch/csrc/jit/graph_executor.cpp
@@ -456,13 +456,22 @@ struct GraphExecutorImpl {
}

void runOptimization(std::shared_ptr<Graph>& graph, const ArgumentSpec& spec) {
// Basic graph preprocessing to eliminate noise.
EliminateDeadCode(graph);
EliminateCommonSubexpression(graph);
ConstantPooling(graph);
UnrollLoops(graph);

PeepholeOptimize(graph);
CheckInplace(graph);

// Unroll small loops, and eliminate expressions that are the same at every
// iteration.
UnrollLoops(graph);
EliminateCommonSubexpression(graph);

// Rewrite subgraphs with many MMs into expressions that batch them.
BatchMM(graph);

CheckInplace(graph);
}

void runNondiffOptimization(std::shared_ptr<Graph>& graph) {
18 changes: 15 additions & 3 deletions torch/csrc/jit/ir.cpp
@@ -916,15 +916,23 @@ Node* Node::insertAfter(Node * n) {
}

bool Node::moveAfterTopologicallyValid(Node* n, const AliasDb& aliasDb) {
return tryMove(n, MoveSide::AFTER, aliasDb);
return tryMove(n, MoveSide::AFTER, aliasDb, /*dryRun=*/false);
}

bool Node::couldMoveAfterTopologically(Node* n, const AliasDb& aliasDb) {
return tryMove(n, MoveSide::AFTER, aliasDb, /*dryRun=*/true);
}

bool Node::moveBeforeTopologicallyValid(Node* n, const AliasDb& aliasDb) {
// We have to distinguish the move side (instead of just moving after
// n->prev()). Consider the following example:
// If the dependency graph looks like this -> n -> o then moveBefore(o) will
// end up with [this, o, n], but moveAfter(n) will return false.
return tryMove(n, MoveSide::BEFORE, aliasDb);
return tryMove(n, MoveSide::BEFORE, aliasDb, /*dryRun=*/false);
}

bool Node::couldMoveBeforeTopologically(Node* n, const AliasDb& aliasDb) {
return tryMove(n, MoveSide::BEFORE, aliasDb, /*dryRun=*/true);
}

// Helper for topologically-safe node moves. See `tryMove()` for details.
@@ -1104,7 +1112,7 @@ struct WorkingSet {
// node at a time. When we can't move past a node (because it depends on the
// working set), then add it to the working set and keep moving until we hit
// `moveAfter`.
bool Node::tryMove(Node* movePoint, MoveSide moveSide, const AliasDb& aliasDb) {
bool Node::tryMove(Node* movePoint, MoveSide moveSide, const AliasDb& aliasDb, bool dryRun) {
JIT_ASSERT(this->inBlockList() && movePoint->inBlockList());
JIT_ASSERT(this->owningBlock() == movePoint->owningBlock());
if (this == movePoint) {
Expand Down Expand Up @@ -1165,6 +1173,10 @@ bool Node::tryMove(Node* movePoint, MoveSide moveSide, const AliasDb& aliasDb) {
return false;
}

if (dryRun) {
return true;
}

// 3. Execute the move
JIT_ASSERT(curNode == movePoint);
if (splitThisAndDeps) {
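The dry-run plumbing added to tryMove above lets a pass ask whether two nodes could be placed next to each other without mutating the graph. A rough, self-contained model of the working-set idea described in the comment (not the real implementation, which also consults the AliasDb and can split a node from its dependencies):

def could_move_after(order, deps, node, move_point):
    """Toy dry-run check: could `node` move to just after `move_point`?

    order: node names in their current topological order
           (assumes `node` currently precedes `move_point`).
    deps:  maps a node name to the set of node names it reads from.
    """
    working_set = {node}
    between = order[order.index(node) + 1:order.index(move_point) + 1]
    for other in between:
        if deps.get(other, set()) & working_set:
            # `other` depends on the working set, so it would have to travel
            # along with it; if that node is the move point itself, the move
            # cannot be performed.
            if other == move_point:
                return False
            working_set.add(other)
    # Dry run: report feasibility without reordering anything.
    return True


# b reads from a, c is independent: a can hop past c, but not past b.
order = ["a", "b", "c"]
deps = {"b": {"a"}}
assert could_move_after(order, deps, "a", "c")
assert not could_move_after(order, deps, "a", "b")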
10 changes: 9 additions & 1 deletion torch/csrc/jit/ir.h
@@ -486,6 +486,10 @@ struct Node : public Attributes<Node> {
// violating dependencies, otherwise executes the move and returns `true`
TORCH_API bool moveAfterTopologicallyValid(Node* n, const AliasDb& aliasDb);

// Like moveAfterTopologicallyValid, but only reports whether the move is
// possible, without actually performing it.
TORCH_API bool couldMoveAfterTopologically(Node* n, const AliasDb& aliasDb);

// Move a node 'n' (already in the graph) before 'this' in the topological
// order.
//
@@ -510,6 +514,10 @@ struct Node : public Attributes<Node> {
// violating dependencies, otherwise executes the move and returns `true`
TORCH_API bool moveBeforeTopologicallyValid(Node* n, const AliasDb& aliasDb);

// Like moveBeforeTopologicallyValid, but only reports whether the move is
// possible, without actually performing it.
TORCH_API bool couldMoveBeforeTopologically(Node* n, const AliasDb& aliasDb);

// Remove the input at 'i' from this node.
//
// WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling
@@ -589,7 +597,7 @@ struct Node : public Attributes<Node> {

private:
enum class MoveSide { BEFORE, AFTER };
bool tryMove(Node* movePoint, MoveSide moveSide, const AliasDb& aliasDb);
bool tryMove(Node* movePoint, MoveSide moveSide, const AliasDb& aliasDb, bool dryRun);
void move(Node* movePoint, MoveSide moveSide);

std::pair<Value*, const Argument&> findInput(Symbol name);
6 changes: 5 additions & 1 deletion torch/csrc/jit/passes/alias_analysis.cpp
@@ -182,6 +182,8 @@ void AliasDb::analyze(Node* node) {
case prim::TupleConstruct:
case prim::Undefined:
case prim::FusedConcat:
case prim::MMTreeReduce:
case prim::MMBatchSide:
return analyzeCreator(node);
case prim::TupleUnpack:
case prim::TupleIndex:
@@ -354,7 +356,9 @@ void AliasDb::analyzeSubgraph(Node* node) {

// For nodes that generate a fresh value from nothing
void AliasDb::analyzeCreator(Node* node) {
giveFreshAlias(node->output());
for (Value * output : node->outputs()) {
giveFreshAlias(output);
}
}

// For nodes that extract values from a composite type. Right now, this just