From ff0de38039c13b8a6a3ce418d6035ee353da23b2 Mon Sep 17 00:00:00 2001 From: wangg12 Date: Mon, 23 Nov 2015 15:44:59 +0800 Subject: [PATCH] add an option to plot network in other shapes --- .gitignore | 3 + example/notebooks/simple_bind.ipynb | 103 +++++++++++++--------------- python/mxnet/visualization.py | 11 ++- 3 files changed, 59 insertions(+), 58 deletions(-) diff --git a/.gitignore b/.gitignore index d476228d5e8a..7ca2c56f9ef2 100644 --- a/.gitignore +++ b/.gitignore @@ -79,3 +79,6 @@ R-package/inst/* *.zip *ubyte *.bin + +# ipython notebook +example/notebooks/.ipynb_checkpoints/* diff --git a/example/notebooks/simple_bind.ipynb b/example/notebooks/simple_bind.ipynb index efa1afa497a0..444b30b0d2d1 100644 --- a/example/notebooks/simple_bind.ipynb +++ b/example/notebooks/simple_bind.ipynb @@ -46,82 +46,71 @@ "\n", "\n", - "\n", "\n", - "\n", - "\n", + "\n", + "\n", "plot\n", - "\n", - "\n", - "data\n", - "\n", - "data\n", - "\n", + "\n", "\n", - "fc1\n", - "\n", - "FullyConnected\n", - "128\n", - "\n", - "\n", - "fc1->data\n", - "\n", - "\n", - "784\n", + "fc1\n", + "\n", + "FullyConnected\n", + "128\n", "\n", "\n", - "bn1\n", - "\n", - "BatchNorm\n", + "bn1\n", + "\n", + "BatchNorm\n", "\n", "\n", - "bn1->fc1\n", - "\n", - "\n", - "128\n", + "bn1->fc1\n", + "\n", + "\n", + "128\n", "\n", "\n", - "act1\n", - "\n", - "Activation\n", - "tanh\n", + "act1\n", + "\n", + "Activation\n", + "tanh\n", "\n", "\n", - "act1->bn1\n", - "\n", - "\n", - "128\n", + "act1->bn1\n", + "\n", + "\n", + "128\n", "\n", "\n", - "fc2\n", - "\n", - "FullyConnected\n", - "10\n", + "fc2\n", + "\n", + "FullyConnected\n", + "10\n", "\n", "\n", - "fc2->act1\n", - "\n", - "\n", - "128\n", + "fc2->act1\n", + "\n", + "\n", + "128\n", "\n", "\n", - "softmax\n", - "\n", - "Softmax\n", + "softmax\n", + "\n", + "SoftmaxOutput\n", "\n", "\n", - "softmax->fc2\n", - "\n", - "\n", - "10\n", + "softmax->fc2\n", + "\n", + "\n", + "10\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 2, @@ -140,7 +129,7 @@ "# visualize the network\n", "batch_size = 100\n", "data_shape = (batch_size, 784)\n", - "mx.viz.plot_network(softmax, shape={\"data\":data_shape})" + "mx.viz.plot_network(softmax, shape={\"data\":data_shape}, node_attrs={\"shape\":'oval',\"fixedsize\":'false'})" ] }, { @@ -454,21 +443,21 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 2", "language": "python", - "name": "python3" + "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 3 + "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.4.2" + "pygments_lexer": "ipython2", + "version": "2.7.10" } }, "nbformat": 4, diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py index 0fcc35894759..5a8dea96023e 100644 --- a/python/mxnet/visualization.py +++ b/python/mxnet/visualization.py @@ -1,6 +1,7 @@ # coding: utf-8 # pylint: disable=invalid-name, too-many-locals, fixme # pylint: disable=too-many-branches, too-many-statements +# pylint: disable=dangerous-default-value """Visualization module""" from __future__ import absolute_import @@ -25,7 +26,7 @@ def _str2tuple(string): return re.findall(r"\d+", string) -def plot_network(symbol, title="plot", shape=None): +def plot_network(symbol, title="plot", shape=None, node_attrs={}): """convert symbol to dot object for visualization Parameters @@ -36,6 +37,11 @@ def plot_network(symbol, title="plot", shape=None): symbol to be visualized shape: dict dict of shapes, str->shape (tuple), given input shapes + node_attrs: dict + dict of node's attributes + for example: + node_attrs={"shape":"oval","fixedsize":"fasle"} + means to plot the network in "oval" Returns ------ dot: Diagraph @@ -59,8 +65,11 @@ def plot_network(symbol, title="plot", shape=None): conf = json.loads(symbol.tojson()) nodes = conf["nodes"] heads = set([x[0] for x in conf["heads"]]) # TODO(xxx): check careful + # default attributes of node node_attr = {"shape": "box", "fixedsize": "true", "width": "1.3", "height": "0.8034", "style": "filled"} + # merge the dcit provided by user and the default one + node_attr.update(node_attrs) dot = Digraph(name=title) # color map cm = ("#8dd3c7", "#fb8072", "#ffffb3", "#bebada", "#80b1d3",