Add a guide to the GraphDef file format for tool developers.
Change: 113413158
petewarden authored and Vijay Vasudevan committed Jan 30, 2016
1 parent b2a64c3 commit 6b83549
Showing 2 changed files with 206 additions and 0 deletions.
8 changes: 8 additions & 0 deletions tensorflow/g3doc/how_tos/index.md
@@ -84,3 +84,11 @@ different locations in the model construction code.
The "Variable Scope" mechanism is designed to facilitate that.

[View Tutorial](../how_tos/variable_scope/index.md)

## A Tool Developer's Guide to TensorFlow Model Files

If you're developing a tool to load, analyze, or manipulate TensorFlow model
files, it's useful to understand a bit about the format in which they're stored.
This guide covers the details of the saved model format.

[View Tutorial](../how_tos/tool_developers/index.md)
198 changes: 198 additions & 0 deletions tensorflow/g3doc/how_tos/tool_developers/index.md
@@ -0,0 +1,198 @@
# A Tool Developer's Guide to TensorFlow Model Files

Most users shouldn't need to care about the internal details of how TensorFlow
stores data on disk, but you might if you're a tool developer. For example, you
may want to analyze models, or convert back and forth between TensorFlow and
other formats. This guide tries to explain some of the details of how you can
work with the main files that hold model data, to make it easier to develop
those kinds of tools.

[TOC]

## Protocol Buffers

All of TensorFlow's file formats are based on
[Protocol Buffers](https://developers.google.com/protocol-buffers/?hl=en), so to
start it's worth getting familiar with how they work. The summary is that you
define data structures in text files, and the protobuf tools generate classes in
C++, Python, and other languages that can load, save, and access the data in a
friendly way. We often refer to Protocol Buffers as protobufs, and I'll use
that convention in this guide.

## GraphDef

The foundation of computation in TensorFlow is the `Graph` object. This holds a
network of nodes, each representing one operation, connected to each other as
inputs and outputs. After you've created a `Graph` object, you can save it out
by calling `as_graph_def()`, which returns a `GraphDef` object.
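
For instance, here's a minimal sketch of producing a `GraphDef` through the
Python API. The constant and its name are just placeholders for illustration:

```python
import tensorflow as tf

with tf.Graph().as_default() as graph:
  # A trivial graph containing a single constant node.
  tf.constant([1.0, 2.0], name="my_constant")

graph_def = graph.as_graph_def()
print(graph_def)  # prints the text-format protobuf representation
```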

The `GraphDef` class is an object created by the ProtoBuf library from the
definition in
[tensorflow/core/framework/graph.proto](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto).
The protobuf tools parse this text file, and generate the code to load, store,
and manipulate graph definitions. If you see a standalone TensorFlow file
representing a model, it's likely to contain a serialized version of one of
these `GraphDef` objects saved out by the protobuf code.

This generated code is used to save and load the `GraphDef` files from disk. A
good example to look at as we dig into this is
[graph_metrics.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/graph_metrics.py).
This Python script takes a saved graph definition, and analyzes the model to
estimate performance and resource statistics. The code that actually loads the
model looks like this:

```python
graph_def = graph_pb2.GraphDef()
```

This line creates an empty `GraphDef` object, an instance of the class that's
been generated from the textual definition in graph.proto. This is the object
we're going to populate with the data from our file.

```python
with open(FLAGS.graph, "rb") as f:
```

Here we get a file handle for the path we've passed into the script.

```python
if FLAGS.input_binary:
  graph_def.ParseFromString(f.read())
else:
  text_format.Merge(f.read(), graph_def)
```

## Text or Binary?

There are actually two different formats that a ProtoBuf can be saved in.
TextFormat is a human-readable form, which makes it nice for debugging and
editing, but can get large when there's numerical data like weights stored in
it. You can see a small example of that in
[poly5-graph.pbtxt](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/components/tf-tensorboard/demo/data/poly5-graph.pbtxt).

Binary format files are a lot smaller than their text equivalents, even though
they're not as readable for us. In this script, we ask the user to supply a
flag indicating whether the input file is binary or text, so we know the right
function to call. You can find an example of a large binary file inside the
[inception_dec_2015.zip archive](https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip),
as `tensorflow_inception_graph.pb`.

The API itself can be a bit confusing - the binary call is actually
`ParseFromString()`, whereas you use a utility function from the `text_format`
module to load textual files.
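
Writing a `GraphDef` out in either format only needs the standard protobuf
APIs. Here's a minimal sketch (the file names are just placeholders):

```python
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2

graph_def = graph_pb2.GraphDef()  # assume this has been populated already

# Binary format: compact, but not human-readable.
with open("model.pb", "wb") as f:
  f.write(graph_def.SerializeToString())

# Text format: readable and editable, but much larger for real models.
with open("model.pbtxt", "w") as f:
  f.write(text_format.MessageToString(graph_def))
```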

## Nodes

Once you've loaded a file into the `graph_def` variable, you can now access the
data inside it. For most practical purposes, the important section is the list
of nodes stored in the `node` member. Here's the code that loops through those:

```python
for node in graph_def.node:
```

Each node is a `NodeDef` object, also defined in graph.proto. These are the
fundamental building blocks of TensorFlow graphs, with each one defining a
single operation along with its input connections. Here are the members of a
`NodeDef`, and what they mean.

### `name`

Every node should have a unique identifier that's not used by any other nodes
in the graph. If you don't specify one as you're building a graph using the
Python API, one reflecting the name of the operation, such as "MatMul",
concatenated with a monotonically increasing number, such as "5", will be
picked for you. The name is used when defining the connections between nodes,
and when setting inputs and outputs for the whole graph when it's run.

### `op`

This defines what operation to run, for example `"Add"`, `"MatMul"`, or
`"Conv2D"`. When a graph is run, this op name is looked up in a registry to
find an implementation. The registry is populated by calls to the
`REGISTER_OP()` macro, like those in
[tensorflow/core/ops/nn_ops.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/nn_ops.cc).

### `input`

A list of strings, each one of which is the name of another node, optionally
followed by a colon and an output port number. For example, a node with two
inputs might have a list like `["some_node_name", "another_node_name"]`, which
is equivalent to `["some_node_name:0", "another_node_name:0"]`, and defines the
node's first input as the first output from the node with the name
`"some_node_name"`, and a second input from the first output of
`"another_node_name"`.

### `device`

In most cases you can ignore this, since it defines where to run a node in a
distributed environment, or when you want to force the operation onto CPU or
GPU.

### `attr`

This is a key/value store holding all the attributes of a node. These are the
permanent properties of nodes, things that don't change at runtime such as the
size of filters for convolutions, or the values of constant ops. Because there
can be so many different types of attribute values, from strings, to ints, to
arrays of tensor values, there's a separate protobuf file defining the data
structure that holds them, in
[tensorflow/core/framework/attr_value.proto](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/attr_value.proto).

Each attribute has a unique name string, and the expected attributes are listed
when the operation is defined. If an attribute isn't present in a node, but it
has a default listed in the operation definition, that default is used when the
graph is created.
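
Because an attribute value can hold any of those types, you read the field that
matches the attribute you're after. Here's a brief, hypothetical example for a
`Conv2D` node (the `conv_node` variable is assumed, not taken from the script
above):

```python
# `conv_node` is assumed to be a NodeDef whose op is "Conv2D".
strides = list(conv_node.attr["strides"].list.i)  # a list of ints
padding = conv_node.attr["padding"].s             # a bytes string, e.g. b"SAME"
data_type = conv_node.attr["T"].type              # a DataType enum value
```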

You can access all of these members as `node.name`, `node.op`, and so on in
Python. The list of nodes stored in the `GraphDef` is a full definition of the
model architecture.
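
As a quick illustration, here's a minimal sketch (assuming `graph_def` has been
loaded as shown earlier) that tallies which ops a model uses:

```python
import collections

op_counts = collections.Counter(node.op for node in graph_def.node)
for op_name, count in op_counts.most_common():
  print("%s: %d" % (op_name, count))
```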

## Freezing

One confusing part about this is that the weights usually aren't stored inside
the file format during training. Instead, they're held in separate checkpoint
files, and there are `Variable` ops in the graph that load the latest values
when they're initialized. It's often not very convenient to have separate files
when you're deploying to production, so there's the
[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)
script that takes a graph definition and a set of checkpoints and freezes them
together into a single file.

What this does is load the `GraphDef`, pull in the values for all the variables
from the latest checkpoint file, and then replace each `Variable` op with a
`Const` that has the numerical data for the weights stored in its attributes.
It then strips away all the extraneous nodes that aren't used for forward
inference, and saves out the resulting `GraphDef` into an output file.
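
The node-level rewrite at the heart of that process looks roughly like the
sketch below. This is a simplified illustration rather than the actual
freeze_graph.py code, and it assumes you've already restored the variable's
value from a checkpoint as a numpy array:

```python
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import tensor_util

def variable_to_const(variable_node, value):
  """Builds a Const NodeDef holding `value` in place of a Variable node."""
  const_node = graph_pb2.NodeDef()
  const_node.name = variable_node.name
  const_node.op = "Const"
  tensor_proto = tensor_util.make_tensor_proto(value)
  const_node.attr["value"].tensor.CopyFrom(tensor_proto)
  const_node.attr["dtype"].type = tensor_proto.dtype
  return const_node
```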

## Weight Formats

If you're dealing with TensorFlow models that represent neural networks, one of
the most common problems is extracting and interpreting the weight values. A
common way to store them, for example in graphs created by the freeze_graph
script, is as `Const` ops containing the weights as `Tensors`. These are
defined in
[tensorflow/core/framework/tensor.proto](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/tensor.proto),
and contain information about the size and type of the data, as well as the
values themselves. In Python, you get a `TensorProto` object from a `NodeDef`
representing a `Const` op by calling something like
`some_node_def.attr['value'].tensor`.

This will give you an object representing the weights data. The data itself
will be stored in one of the lists with the suffix `_val`, as indicated by the
type of the object, for example `float_val` for 32-bit float data types.
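
One way to avoid dealing with those per-type lists yourself is
`tensor_util.MakeNdarray()`, which converts a `TensorProto` into a numpy array.
A minimal sketch, assuming `graph_def` has already been loaded:

```python
from tensorflow.python.framework import tensor_util

for node in graph_def.node:
  if node.op == "Const":
    tensor_proto = node.attr["value"].tensor
    values = tensor_util.MakeNdarray(tensor_proto)
    print("%s %s %s" % (node.name, values.shape, values.dtype))
```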

The ordering of convolution weight values is often tricky to deal with when
converting between different frameworks. In TensorFlow, the filter weights for
the `Conv2D` operation are stored on the second input, and are expected to be
in the order `[filter_height, filter_width, input_depth, output_depth]`, where
`output_depth` increasing by one means moving to an adjacent value in memory.

Hopefully this rundown gives you a better idea of what's going on inside
TensorFlow model files, and will help you if you ever need to manipulate them.
