Skip to content

Commit

Permalink
Merge pull request keras-team#1046 from julienr/visutil_to_graph
Browse files Browse the repository at this point in the history
Add visualize_util.to_graph and docs
  • Loading branch information
fchollet committed Nov 20, 2015
2 parents 6a231f1 + 71b258d commit e1cc291
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pages:
- Constraints: constraints.md
- Callbacks: callbacks.md
- Datasets: datasets.md
- Visualization: visualization.md
- Layers:
- Core Layers: layers/core.md
- Convolutional Layers: layers/convolutional.md
Expand Down
20 changes: 20 additions & 0 deletions docs/sources/visualization.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

## Model visualization

The `keras.utils.visualize_util` module provides utility functions to plot
a Keras model (using graphviz).

This will plot a graph of the model and save it to a file:
```python
from keras.utils.visualize_util import plot
plot(model, to_file='model.png')
```

You can also directly obtain the `pydot.Graph` object and render it yourself,
for example to show it in an ipython notebook :
```python
from IPython.display import SVG
from keras.utils.visualize_util import to_graph

SVG(to_graph(model).create(prog='dot', format='svg'))
```
10 changes: 5 additions & 5 deletions keras/utils/visualize_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
# that works with python3 such as pydot2 or pydot
from keras.models import Sequential, Graph

def plot(model, to_file='model.png'):

def to_graph(model):
graph = pydot.Dot(graph_type='digraph')
if type(model) == Sequential:
previous_node = None
Expand All @@ -20,8 +19,6 @@ def plot(model, to_file='model.png'):
if previous_node:
graph.add_edge(pydot.Edge(previous_node, current_node))
previous_node = current_node
graph.write_png(to_file)

elif type(model) == Graph:
# don't need to append number for names since all nodes labeled
for input_node in model.input_config:
Expand All @@ -37,5 +34,8 @@ def plot(model, to_file='model.png'):
graph.add_edge(pydot.Edge(e, node['name']))
else:
graph.add_edge(pydot.Edge(node['input'], node['name']))
return graph

graph.write_png(to_file)
def plot(model, to_file='model.png'):
graph = to_graph(model)
graph.write_png(to_file)

0 comments on commit e1cc291

Please sign in to comment.