From 15180b55e57ebdb172f4d9d2acad7aaec8826cfc Mon Sep 17 00:00:00 2001 From: wenlihaoyu <wenyinlong52@163.com> Date: Wed, 4 Sep 2019 00:03:36 +0800 Subject: [PATCH] add keras model to tf pb model for opencv dnn --- tools/keras_to_pb.py | 143 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 tools/keras_to_pb.py diff --git a/tools/keras_to_pb.py b/tools/keras_to_pb.py new file mode 100644 index 0000000..48736f5 --- /dev/null +++ b/tools/keras_to_pb.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Wed Sep 4 00:00:43 2019 +keras to pd for opencv dnn +@author: chineseocr +""" + + +import tensorflow as tf +import os +from keras import backend as K +from tensorflow.python.framework import graph_util,graph_io +def keras_to_pb(kerasmodel,outputDir,modelName='model.pd',outName = "output_"): + if not os.path.exists(outputDir): + os.makedirs(outputDir) + out_nodes = [] + for i in range(len(kerasmodel.outputs)): + out_nodes.append(outName + str(i + 1)) + tf.identity(kerasmodel.outputs[i],outName + str(i + 1)) + sess = K.get_session() + init_graph = sess.graph.as_graph_def() + main_graph = graph_util.convert_variables_to_constants(sess,init_graph,out_nodes) + graph_io.write_graph(main_graph,outputDir,name = modelName,as_text = False) + + +def pd_to_pbtxt(pdPath): + with tf.gfile.FastGFile(pdPath,'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + + for i in reversed(range(len(graph_def.node))): + if graph_def.node[i].op == 'Const': + del graph_def.node[i] + for attr in ['T', 'data_format', 'Tshape', 'N', 'Tidx', 'Tdim', + 'use_cudnn_on_gpu', 'Index', 'Tperm', 'is_training', + 'Tpaddings']: + if attr in graph_def.node[i].attr: + del graph_def.node[i].attr[attr] + + path,filename = os.path.split(pdPath) + filename = filename.replace('.pb','.pbtxt') + tf.train.write_graph(graph_def,path, filename, as_text=True) + + +def remove_node(txt,name): + index = txt.find('node {\n '+ name ) + ind = index + punc = [] + flag=False + if index>=0: + for i in range(ind,len(txt)): + if txt[i]=='{': + punc.append('{') + elif txt[i]=='}': + if '{' in punc: + punc.pop(-1) + if len(punc)==0: + flag=True + break + + + + if flag: + txt=txt.replace(txt[index:i+1],'') + + return txt + + + +if __name__=='__main__': + ## demo vgg16 to pb + + """ + Open model.pbtxt and remove nodes with names strided_slice,flatten/Shape, flatten/strided_slice, flatten/Prod, flatten/stack. + Replace the node + node { + name: "flatten/Reshape" + op: "Reshape" + input: "block5_pool/MaxPool" + input: "flatten/stack" + } + + on + + node { + name: "flatten/Reshape" + op: "Flatten" + input: "block5_pool/MaxPool" + } + + """ + def pbtxt_adjust(pbtxt): + with open(pbtxt) as f: + txt = f.read() + + nodename = 'node {\n name: "flatten/Reshape"' + ## replace + index = txt.find(nodename) + if index>0: + for ind in range(index,len(txt)): + if txt[ind]=='}': + break + replacestr = txt[index:ind+1] + txt=txt.replace(replacestr,replacestr.replace('op: "Reshape"','op: "Flatten"')) + + + ## del node + delnamelist = [ 'name: "flatten/Shape"', + 'name: "flatten/strided_slice"', + 'name: "flatten/Prod"', + 'name: "flatten/stack"' + ] + + for delname in delnamelist: + txt = remove_node(txt,delname) + + with open(pbtxt,'w') as f: + f.write(txt) + + + from keras.applications.vgg16 import VGG16 + import cv2 + import numpy as np + vgg = VGG16(weights=None) + name = vgg.name + modelName = name+'.pb' + outputDir=os.path.join('/tmp/',name) + keras_to_pb(vgg,outputDir,modelName=modelName,outName = "output_") + pb = os.path.join(outputDir,modelName) + pd_to_pbtxt(pb) + pbtxt = os.path.join(outputDir,name+'.pbtxt') + pbtxt_adjust(pbtxt) + + dnn = cv2.dnn.readNetFromTensorflow(pb,pbtxt) + inputBlob = np.zeros((1,3,224,224)) + dnn.setInput(inputBlob) + pred = dnn.forward() + print('dnn:',pred[0][:10]) + print('vgg:',vgg.predict(np.zeros((1,224,224,3)))[0][:10]) + + + \ No newline at end of file