import itertools
import operator
from typing import Dict, List, Tuple

import torch
from torch.fx._symbolic_trace import symbolic_trace
from torch.fx.node import Node
from torch.fx.passes.tools_common import legalize_graph


def split_result_tensors(
    result: torch.Tensor, inputs: List[torch.Tensor]
) -> Tuple[torch.Tensor, ...]:
    """
    A free function for use in the merge_matmul graph transformation below that
    splits the output from a merged matmul into the individual results for each
    input tensor.

    Arguments:
        result: The merged matmul result tensor.
        inputs: The list of inputs that were merged into one for the matmul.

    Returns:
        Tuple of matmul results, one for each input tensor.
    """
    # When the fx tracer is running, x.shape[0] is a Proxy attribute access
    # rather than an int, so fall back to placeholder split sizes during tracing.
    if isinstance(result, torch.fx.Proxy):
        splits = [0] * len(inputs)
    else:
        splits = [x.shape[0] for x in inputs]

    return torch.split(result, splits)
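
# Illustrative sketch (not part of the original file): with eager inputs of
# shapes (2, 4) and (3, 4), the merged result of cat + matmul against a (4, 6)
# right-hand side has shape (5, 6), and the split sizes come from each input's
# leading dimension:
#
#     a, b = torch.randn(2, 4), torch.randn(3, 4)
#     merged = torch.cat([a, b]) @ torch.randn(4, 6)
#     out_a, out_b = split_result_tensors(merged, [a, b])
#     assert out_a.shape == (2, 6) and out_b.shape == (3, 6)
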
def may_depend_on(a: Node, b: Node, search_depth: int = 6) -> bool:
    """
    Determine if one node depends on another in a torch.fx.Graph.

    Arguments:
        a: The node that may have a dependency on b.
        b: The node that a may have a dependency on.
        search_depth: In the case of an indirect dependency, this function
                      searches up to this many nodes away in search of a
                      data dependency. If none is found, the function
                      makes the conservative assumption that there is a
                      dependency.

    Returns:
        True if a may depend on b, False if it definitely does not.
    """
    # Equivalence is defined as dependence.
    if a == b:
        return True

    # If a has no inputs, it cannot depend on b.
    if len(a.all_input_nodes) == 0:
        return False

    # If the search depth has been exhausted and no conclusion has been
    # reached, assume that there is a data dependency.
    if search_depth == 0:
        return True

    # Recursively check all inputs of a.
    for inp in a.all_input_nodes:
        if may_depend_on(inp, b, search_depth - 1):
            return True

    return False
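
# Illustrative sketch (hypothetical graph, not part of the original file): in a
# trace of `y = (x + 1) * 2`, the mul node's input chain reaches the add node
# within the default search depth, so may_depend_on(mul, add) is True, while
# may_depend_on(add, mul) is False since add's inputs never reach mul.
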
def are_nodes_independent(nodes: List[Node]) -> bool:
    """
    Check if all of the given nodes are pairwise data-independent.

    Arguments:
        nodes: The nodes to check for data dependencies.

    Returns:
        True if no pair of nodes has a data dependency, False otherwise.
    """
    # For each pair in nodes:
    for i, j in itertools.combinations(nodes, 2):
        if may_depend_on(i, j) or may_depend_on(j, i):
            return False

    return True
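
# Why this matters for the pass below: if one matmul consumed another's output,
# concatenating their left-hand sides into a single matmul would place a use
# before its definition. Merging is therefore attempted only when every pair of
# candidate matmuls is (conservatively) independent.
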
def merge_matmul(in_mod: torch.nn.Module) -> torch.fx.GraphModule:
    """
    A graph transformation that merges matrix multiplication operations that
    share the same right-hand side operand into one large matrix multiplication.

                   ____      _________        _________
          ----    |    |    |         |     M|  A * C  |
         M| A  |  T| B | * K|    C    | =    |---------|
          ----    ,    |    |         |     T|  B * C  |
           K      ----      ---------        ---------
                    K            R                R
    """
    gm = symbolic_trace(in_mod)

    rhs_users: Dict[Node, List[Node]] = {}
    lhs_users: Dict[Node, List[Node]] = {}

    # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands
    # to the matmuls of which they are the LHS/RHS. Only rhs_users is merged below;
    # lhs_users is collected symmetrically but not yet used.
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target is not torch.matmul:
            continue

        lhs, rhs = node.args

        # TODO: Properly handle aliasing caused by get_attr. For now,
        # use the attribute name as the operand if the node is a
        # get_attr.
        lhs = lhs.target if lhs.op == "get_attr" else lhs
        rhs = rhs.target if rhs.op == "get_attr" else rhs

        lhs_users.setdefault(lhs, []).append(node)
        rhs_users.setdefault(rhs, []).append(node)

    for rhs, mms in rhs_users.items():
        # There must be at least two matmuls for a merge to make sense.
        if len(mms) < 2:
            continue

        # All matmuls must not depend on each other directly or indirectly
        # in order for the merge to be possible.
        if not are_nodes_independent(mms):
            continue

        lhs_vals = [mm.args[0] for mm in mms]

        # Merge the matmul.
        # Collect a list of LHS operands and the single RHS operand.
        lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals]
        rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs

        # Concatenate all the LHS operands.
        merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {})

        # Multiply the concatenated LHS operands with the one RHS. This will produce
        # the same results as all the individual matmuls involving rhs in the
        # original graph, but they will all be concatenated together.
        merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs), {})

        # Split the result of the merged matmul using the shapes of the LHS operands
        # to ascertain how large each chunk should be.
        merge_mm_split = gm.graph.call_function(
            split_result_tensors, (merge_mm, lhs), {}
        )
        merge_mm_res = [
            gm.graph.call_function(operator.getitem, (merge_mm_split, out), {})
            for out in range(len(lhs))
        ]

        # Replace all uses of the original, unmerged matmuls with the equivalent
        # split chunk from the merged matmul.
        for old, new in zip(mms, merge_mm_res):
            old.replace_all_uses_with(new)
            gm.graph.erase_node(old)

    # All of the new nodes created above were inserted at the end, so we need to sort
    # the nodes topologically to make sure all definitions precede uses.
    legalize_graph(gm)

    gm.recompile()
    gm.graph.lint()
    return gm
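

if __name__ == "__main__":
    # Minimal usage sketch (illustrative; the module, names, and shapes below
    # are made up, not part of the original file). Two matmuls share the
    # right-hand side `self.weight`, so the pass fuses them into a single
    # cat -> matmul -> split chain while preserving the numerics.
    class TwoMatmuls(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(4, 5))

        def forward(self, a, b):
            return torch.matmul(a, self.weight), torch.matmul(b, self.weight)

    mod = TwoMatmuls()
    a, b = torch.randn(2, 4), torch.randn(3, 4)
    merged = merge_matmul(mod)

    # The transformed module should be numerically equivalent to the original.
    for expected, actual in zip(mod(a, b), merged(a, b)):
        torch.testing.assert_close(expected, actual)
    print(merged.graph)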