Skip to content

Commit

Permalink
Merge pull request ethereon#28 from Russell91/master
Browse files Browse the repository at this point in the history
Fixed run_mnist.py to correct name, finetune_mnist.py
  • Loading branch information
ethereon committed Apr 29, 2016
2 parents a4a773e + 8091c11 commit 99ac7f5
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 10 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@ First, you can convert a prototxt model to tensorflow code:

This produces tensorflow code for the LeNet network in mynet.py. The code can be imported as described below in the Inference section. Caffe-tensorflow also lets you convert .caffemodel weight files to .npy files that can be directly loaded from tensorflow:

$ ./convert.py examples/mnist/lenet.prototxt --data_path examples/mnist/lenet_iter_10000.caffemodel
$ ./convert.py examples/mnist/lenet.prototxt --caffemodel examples/mnist/lenet_iter_10000.caffemodel

The above command will generate a weight file named mynet.npy in addition to the mynet.py code.

#### Inference:

Once you have generated both the code weight files for LeNet, you can finetune LeNet using tensorflow with

$ ./examples/mnist/run_mnist.py
$ ./examples/mnist/finetune_mnist.py

At a high level, run_mnist.py works as follows:
At a high level, finetune_mnist.py works as follows:

```python
# Import the converted model's class
Expand Down
10 changes: 5 additions & 5 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
from kaffe.tensorflow import TensorFlowTransformer

def convert(def_path,
data_path=None,
caffemodel=None,
data_output_path='mynet.npy',
code_output_path='mynet.py',
phase='test'):
try:
transformer = TensorFlowTransformer(def_path, data_path, phase=phase)
transformer = TensorFlowTransformer(def_path, caffemodel, phase=phase)
print('Converting data...')
if data_path is not None:
if caffemodel is not None:
data = transformer.transform_data()
print('Saving data...')
with open(data_output_path, 'wb') as data_out:
Expand All @@ -31,15 +31,15 @@ def convert(def_path,
def main():
parser = argparse.ArgumentParser()
parser.add_argument('def_path', help='Model definition (.prototxt) path')
parser.add_argument('--data_path', default=None, help='Model data (.caffemodel) path')
parser.add_argument('--caffemodel', default=None, help='Model data (.caffemodel) path')
parser.add_argument('--data_output_path', default='mynet.npy', help='Converted data output path')
parser.add_argument('--code_output_path', default='mynet.py',
help='Save generated source to this path')
parser.add_argument('-p', '--phase', default='test',
help='The phase to convert: test (default) or train')
args = parser.parse_args()

convert(args.def_path, args.data_path, args.data_output_path, args.code_output_path, args.phase)
convert(args.def_path, args.caffemodel, args.data_output_path, args.code_output_path, args.phase)

if __name__ == '__main__':
main()
2 changes: 0 additions & 2 deletions examples/mnist/convert.sh

This file was deleted.

0 comments on commit 99ac7f5

Please sign in to comment.