forked from minitorch/Module-0
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgraph_builder.py
61 lines (51 loc) · 1.68 KB
/
graph_builder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from collections import deque

import minitorch
import networkx as nx
def build_expression(code):
    """Evaluate ``code`` with scalar variables ``x``, ``y`` and ``z`` in scope.

    Each of ``x``, ``y`` and ``z`` is a fresh ``minitorch.Scalar(1.0)`` named
    after itself. Whatever the expression evaluates to is renamed ``"out"``
    and returned.

    NOTE(review): ``eval`` executes arbitrary Python. The namespace has no
    ``__builtins__`` key, so Python injects the real builtins. Only pass
    expressions from a trusted source (e.g. the person running this tool).
    """
    namespace = {
        label: minitorch.Scalar(1.0, name=label) for label in ("x", "y", "z")
    }
    result = eval(code, namespace)
    result.name = "out"
    return result
class GraphBuilder:
    """Build a ``networkx.MultiDiGraph`` of the computation behind a variable.

    Variables become plain nodes. Every function application ("op") becomes
    a square node with one edge from each of its inputs and one edge to the
    variable it produced.
    """

    def __init__(self):
        # Counter that numbers op nodes. It is not reset between calls to run().
        self.op_id = 0
        # Last alias number handed out for a long variable name.
        self.hid = 0
        # Maps a long variable name to its alias number, so the alias is stable.
        self.intermediates = {}

    def get_name(self, x):
        """Return the graph node label used for ``x``.

        Non-``minitorch.Variable`` values are labelled ``"constant <value>"``.
        Variables whose name is longer than 15 characters get a short, stable
        alias ``"v<k>"``. Every other variable is labelled by its own name.
        """
        if not isinstance(x, minitorch.Variable):
            return "constant %s" % (x,)
        if len(x.name) <= 15:
            return x.name
        if x.name not in self.intermediates:
            self.hid += 1
            self.intermediates[x.name] = self.hid
        return "v%d" % (self.intermediates[x.name],)

    def run(self, final):
        """Walk the history of ``final`` breadth-first and return its graph.

        Args:
            final: the output variable whose computation is drawn.

        Returns:
            nx.MultiDiGraph: op nodes carry ``shape="square"`` and
            ``penwidth=3``. Each input->op edge is keyed by the input's
            position as a string.
        """
        G = nx.MultiDiGraph()
        G.add_node(self.get_name(final))
        queue = deque([final])  # popleft() is O(1), unlike slicing a list
        # id() of every variable already expanded. A variable used several
        # times (e.g. ``v`` in ``v * v``) must get exactly one producing op
        # node rather than one per use.
        expanded = set()
        while queue:
            cur = queue.popleft()
            if cur.is_leaf() or id(cur) in expanded:
                continue
            expanded.add(id(cur))
            op = "%s (Op %d)" % (cur.history.last_fn.__name__, self.op_id)
            G.add_node(op, shape="square", penwidth=3)
            G.add_edge(op, self.get_name(cur))
            self.op_id += 1
            for i, inp in enumerate(cur.history.inputs):
                G.add_edge(self.get_name(inp), op, f"{i}")
                # Only variables carry a history worth expanding; constants
                # are drawn but not traversed.
                if isinstance(inp, minitorch.Variable):
                    queue.append(inp)
        return G