Skip to content

Commit

Permalink
add an option to plot network in other shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
wangg12 committed Nov 23, 2015
1 parent 87a28ca commit ff0de38
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 58 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,6 @@ R-package/inst/*
*.zip
*ubyte
*.bin

# ipython notebook
example/notebooks/.ipynb_checkpoints/*
103 changes: 46 additions & 57 deletions example/notebooks/simple_bind.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -46,82 +46,71 @@
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n",
"<!-- Generated by graphviz version 2.36.0 (20140111.2315)\n",
" -->\n",
"<!-- Title: plot Pages: 1 -->\n",
"<svg width=\"102pt\" height=\"611pt\"\n",
" viewBox=\"0.00 0.00 102.00 611.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 607)\">\n",
"<svg width=\"144pt\" height=\"506pt\"\n",
" viewBox=\"0.00 0.00 144.00 506.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 502)\">\n",
"<title>plot</title>\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-607 98,-607 98,4 -4,4\"/>\n",
"<!-- data -->\n",
"<g id=\"node1\" class=\"node\"><title>data</title>\n",
"<polygon fill=\"#8dd3c7\" stroke=\"black\" points=\"94,-58 -7.10543e-15,-58 -7.10543e-15,-0 94,-0 94,-58\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-25.3\" font-family=\"Times,serif\" font-size=\"14.00\">data</text>\n",
"</g>\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-502 140,-502 140,4 -4,4\"/>\n",
"<!-- fc1 -->\n",
"<g id=\"node2\" class=\"node\"><title>fc1</title>\n",
"<polygon fill=\"#fb8072\" stroke=\"black\" points=\"94,-167 -7.10543e-15,-167 -7.10543e-15,-109 94,-109 94,-167\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-141.8\" font-family=\"Times,serif\" font-size=\"14.00\">FullyConnected</text>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-126.8\" font-family=\"Times,serif\" font-size=\"14.00\">128</text>\n",
"</g>\n",
"<!-- fc1&#45;&gt;data -->\n",
"<g id=\"edge1\" class=\"edge\"><title>fc1&#45;&gt;data</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M47,-98.5824C47,-85.2841 47,-70.632 47,-58.2967\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-108.887 42.5001,-98.887 47,-103.887 47.0001,-98.887 47.0001,-98.887 47.0001,-98.887 47,-103.887 51.5001,-98.8871 47,-108.887 47,-108.887\"/>\n",
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-79.8\" font-family=\"Times,serif\" font-size=\"14.00\">784</text>\n",
"<g id=\"node1\" class=\"node\"><title>fc1</title>\n",
"<ellipse fill=\"#fb8072\" stroke=\"black\" cx=\"68\" cy=\"-29\" rx=\"68.2532\" ry=\"29\"/>\n",
"<text text-anchor=\"middle\" x=\"68\" y=\"-32.8\" font-family=\"Times,serif\" font-size=\"14.00\">FullyConnected</text>\n",
"<text text-anchor=\"middle\" x=\"68\" y=\"-17.8\" font-family=\"Times,serif\" font-size=\"14.00\">128</text>\n",
"</g>\n",
"<!-- bn1 -->\n",
"<g id=\"node3\" class=\"node\"><title>bn1</title>\n",
"<polygon fill=\"#bebada\" stroke=\"black\" points=\"94,-276 -7.10543e-15,-276 -7.10543e-15,-218 94,-218 94,-276\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-243.3\" font-family=\"Times,serif\" font-size=\"14.00\">BatchNorm</text>\n",
"<g id=\"node2\" class=\"node\"><title>bn1</title>\n",
"<ellipse fill=\"#bebada\" stroke=\"black\" cx=\"68\" cy=\"-139\" rx=\"47\" ry=\"29\"/>\n",
"<text text-anchor=\"middle\" x=\"68\" y=\"-135.3\" font-family=\"Times,serif\" font-size=\"14.00\">BatchNorm</text>\n",
"</g>\n",
"<!-- bn1&#45;&gt;fc1 -->\n",
"<g id=\"edge2\" class=\"edge\"><title>bn1&#45;&gt;fc1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M47,-207.582C47,-194.284 47,-179.632 47,-167.297\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-217.887 42.5001,-207.887 47,-212.887 47.0001,-207.887 47.0001,-207.887 47.0001,-207.887 47,-212.887 51.5001,-207.887 47,-217.887 47,-217.887\"/>\n",
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-188.8\" font-family=\"Times,serif\" font-size=\"14.00\">128</text>\n",
"<g id=\"edge1\" class=\"edge\"><title>bn1&#45;&gt;fc1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M68,-99.8131C68,-86.1516 68,-71.0092 68,-58.3283\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"68,-109.906 63.5001,-99.9062 68,-104.906 68.0001,-99.9062 68.0001,-99.9062 68.0001,-99.9062 68,-104.906 72.5001,-99.9062 68,-109.906 68,-109.906\"/>\n",
"<text text-anchor=\"middle\" x=\"78.5\" y=\"-80.3\" font-family=\"Times,serif\" font-size=\"14.00\">128</text>\n",
"</g>\n",
"<!-- act1 -->\n",
"<g id=\"node4\" class=\"node\"><title>act1</title>\n",
"<polygon fill=\"#ffffb3\" stroke=\"black\" points=\"94,-385 -7.10543e-15,-385 -7.10543e-15,-327 94,-327 94,-385\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-359.8\" font-family=\"Times,serif\" font-size=\"14.00\">Activation</text>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-344.8\" font-family=\"Times,serif\" font-size=\"14.00\">tanh</text>\n",
"<g id=\"node3\" class=\"node\"><title>act1</title>\n",
"<ellipse fill=\"#ffffb3\" stroke=\"black\" cx=\"68\" cy=\"-249\" rx=\"48.4635\" ry=\"29\"/>\n",
"<text text-anchor=\"middle\" x=\"68\" y=\"-252.8\" font-family=\"Times,serif\" font-size=\"14.00\">Activation</text>\n",
"<text text-anchor=\"middle\" x=\"68\" y=\"-237.8\" font-family=\"Times,serif\" font-size=\"14.00\">tanh</text>\n",
"</g>\n",
"<!-- act1&#45;&gt;bn1 -->\n",
"<g id=\"edge3\" class=\"edge\"><title>act1&#45;&gt;bn1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M47,-316.582C47,-303.284 47,-288.632 47,-276.297\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-326.887 42.5001,-316.887 47,-321.887 47.0001,-316.887 47.0001,-316.887 47.0001,-316.887 47,-321.887 51.5001,-316.887 47,-326.887 47,-326.887\"/>\n",
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-297.8\" font-family=\"Times,serif\" font-size=\"14.00\">128</text>\n",
"<g id=\"edge2\" class=\"edge\"><title>act1&#45;&gt;bn1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M68,-209.813C68,-196.152 68,-181.009 68,-168.328\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"68,-219.906 63.5001,-209.906 68,-214.906 68.0001,-209.906 68.0001,-209.906 68.0001,-209.906 68,-214.906 72.5001,-209.906 68,-219.906 68,-219.906\"/>\n",
"<text text-anchor=\"middle\" x=\"78.5\" y=\"-190.3\" font-family=\"Times,serif\" font-size=\"14.00\">128</text>\n",
"</g>\n",
"<!-- fc2 -->\n",
"<g id=\"node5\" class=\"node\"><title>fc2</title>\n",
"<polygon fill=\"#fb8072\" stroke=\"black\" points=\"94,-494 -7.10543e-15,-494 -7.10543e-15,-436 94,-436 94,-494\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-468.8\" font-family=\"Times,serif\" font-size=\"14.00\">FullyConnected</text>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-453.8\" font-family=\"Times,serif\" font-size=\"14.00\">10</text>\n",
"<g id=\"node4\" class=\"node\"><title>fc2</title>\n",
"<ellipse fill=\"#fb8072\" stroke=\"black\" cx=\"68\" cy=\"-359\" rx=\"68.2532\" ry=\"29\"/>\n",
"<text text-anchor=\"middle\" x=\"68\" y=\"-362.8\" font-family=\"Times,serif\" font-size=\"14.00\">FullyConnected</text>\n",
"<text text-anchor=\"middle\" x=\"68\" y=\"-347.8\" font-family=\"Times,serif\" font-size=\"14.00\">10</text>\n",
"</g>\n",
"<!-- fc2&#45;&gt;act1 -->\n",
"<g id=\"edge4\" class=\"edge\"><title>fc2&#45;&gt;act1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M47,-425.582C47,-412.284 47,-397.632 47,-385.297\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-435.887 42.5001,-425.887 47,-430.887 47.0001,-425.887 47.0001,-425.887 47.0001,-425.887 47,-430.887 51.5001,-425.887 47,-435.887 47,-435.887\"/>\n",
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-406.8\" font-family=\"Times,serif\" font-size=\"14.00\">128</text>\n",
"<g id=\"edge3\" class=\"edge\"><title>fc2&#45;&gt;act1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M68,-319.813C68,-306.152 68,-291.009 68,-278.328\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"68,-329.906 63.5001,-319.906 68,-324.906 68.0001,-319.906 68.0001,-319.906 68.0001,-319.906 68,-324.906 72.5001,-319.906 68,-329.906 68,-329.906\"/>\n",
"<text text-anchor=\"middle\" x=\"78.5\" y=\"-300.3\" font-family=\"Times,serif\" font-size=\"14.00\">128</text>\n",
"</g>\n",
"<!-- softmax -->\n",
"<g id=\"node6\" class=\"node\"><title>softmax</title>\n",
"<polygon fill=\"#b3de69\" stroke=\"black\" points=\"94,-603 -7.10543e-15,-603 -7.10543e-15,-545 94,-545 94,-603\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-570.3\" font-family=\"Times,serif\" font-size=\"14.00\">Softmax</text>\n",
"<g id=\"node5\" class=\"node\"><title>softmax</title>\n",
"<ellipse fill=\"#fccde5\" stroke=\"black\" cx=\"68\" cy=\"-469\" rx=\"55.0152\" ry=\"29\"/>\n",
"<text text-anchor=\"middle\" x=\"68\" y=\"-465.3\" font-family=\"Times,serif\" font-size=\"14.00\">SoftmaxOutput</text>\n",
"</g>\n",
"<!-- softmax&#45;&gt;fc2 -->\n",
"<g id=\"edge5\" class=\"edge\"><title>softmax&#45;&gt;fc2</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M47,-534.582C47,-521.284 47,-506.632 47,-494.297\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"47,-544.887 42.5001,-534.887 47,-539.887 47.0001,-534.887 47.0001,-534.887 47.0001,-534.887 47,-539.887 51.5001,-534.887 47,-544.887 47,-544.887\"/>\n",
"<text text-anchor=\"middle\" x=\"54\" y=\"-515.8\" font-family=\"Times,serif\" font-size=\"14.00\">10</text>\n",
"<g id=\"edge4\" class=\"edge\"><title>softmax&#45;&gt;fc2</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M68,-429.813C68,-416.152 68,-401.009 68,-388.328\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"68,-439.906 63.5001,-429.906 68,-434.906 68.0001,-429.906 68.0001,-429.906 68.0001,-429.906 68,-434.906 72.5001,-429.906 68,-439.906 68,-439.906\"/>\n",
"<text text-anchor=\"middle\" x=\"75\" y=\"-410.3\" font-family=\"Times,serif\" font-size=\"14.00\">10</text>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.dot.Digraph at 0x7f9c55b02048>"
"<graphviz.dot.Digraph at 0x7fd757479190>"
]
},
"execution_count": 2,
Expand All @@ -140,7 +129,7 @@
"# visualize the network\n",
"batch_size = 100\n",
"data_shape = (batch_size, 784)\n",
"mx.viz.plot_network(softmax, shape={\"data\":data_shape})"
"mx.viz.plot_network(softmax, shape={\"data\":data_shape}, node_attrs={\"shape\":'oval',\"fixedsize\":'false'})"
]
},
{
Expand Down Expand Up @@ -454,21 +443,21 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 2",
"language": "python",
"name": "python3"
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.4.2"
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
Expand Down
11 changes: 10 additions & 1 deletion python/mxnet/visualization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# coding: utf-8
# pylint: disable=invalid-name, too-many-locals, fixme
# pylint: disable=too-many-branches, too-many-statements
# pylint: disable=dangerous-default-value
"""Visualization module"""
from __future__ import absolute_import

Expand All @@ -25,7 +26,7 @@ def _str2tuple(string):
return re.findall(r"\d+", string)


def plot_network(symbol, title="plot", shape=None):
def plot_network(symbol, title="plot", shape=None, node_attrs={}):
"""convert symbol to dot object for visualization
Parameters
Expand All @@ -36,6 +37,11 @@ def plot_network(symbol, title="plot", shape=None):
symbol to be visualized
shape: dict
dict of shapes, str->shape (tuple), given input shapes
node_attrs: dict
dict of node's attributes
for example:
node_attrs={"shape":"oval","fixedsize":"fasle"}
means to plot the network in "oval"
Returns
------
dot: Diagraph
Expand All @@ -59,8 +65,11 @@ def plot_network(symbol, title="plot", shape=None):
conf = json.loads(symbol.tojson())
nodes = conf["nodes"]
heads = set([x[0] for x in conf["heads"]]) # TODO(xxx): check careful
# default attributes of node
node_attr = {"shape": "box", "fixedsize": "true",
"width": "1.3", "height": "0.8034", "style": "filled"}
# merge the dcit provided by user and the default one
node_attr.update(node_attrs)
dot = Digraph(name=title)
# color map
cm = ("#8dd3c7", "#fb8072", "#ffffb3", "#bebada", "#80b1d3",
Expand Down

0 comments on commit ff0de38

Please sign in to comment.