Skip to content

Commit

Permalink
update mnn freeze model tool
Browse files Browse the repository at this point in the history
  • Loading branch information
MaybeShewill-CV committed Sep 23, 2020
1 parent 6505299 commit 5841528
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions mnn_project/freeze_lanenet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
import tensorflow as tf

from lanenet_model import lanenet
from local_utils.config_utils import parse_config_utils

MODEL_WEIGHTS_FILE_PATH = './test.ckpt'
OUTPUT_PB_FILE_PATH = './lanenet.pb'
CFG = parse_config_utils.lanenet_cfg


def init_args():
Expand All @@ -45,16 +47,22 @@ def convert_ckpt_into_pb_file(ckpt_file_path, pb_file_path):
with tf.variable_scope('lanenet'):
input_tensor = tf.placeholder(dtype=tf.float32, shape=[1, 256, 512, 3], name='input_tensor')

net = lanenet.LaneNet(phase='test', net_flag='vgg')
binary_seg_ret, instance_seg_ret = net.inference(input_tensor=input_tensor, name='lanenet_model')
net = lanenet.LaneNet(phase='test', cfg=CFG)
binary_seg_ret, instance_seg_ret = net.inference(input_tensor=input_tensor, name='LaneNet')

with tf.variable_scope('lanenet/'):
binary_seg_ret = tf.cast(binary_seg_ret, dtype=tf.float32)
binary_seg_ret = tf.squeeze(binary_seg_ret, axis=0, name='final_binary_output')
instance_seg_ret = tf.squeeze(instance_seg_ret, axis=0, name='final_pixel_embedding_output')

# define moving average version of the learned variables for eval
with tf.variable_scope(name_or_scope='moving_avg'):
variable_averages = tf.train.ExponentialMovingAverage(
CFG.SOLVER.MOVING_AVE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()

# create a session
saver = tf.train.Saver()
saver = tf.train.Saver(variables_to_restore)

sess_config = tf.ConfigProto()
sess_config.gpu_options.per_process_gpu_memory_fraction = 0.85
Expand Down

0 comments on commit 5841528

Please sign in to comment.