Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
MaybeShewill-CV committed Sep 23, 2020
1 parent 069b355 commit 6505299
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 10 deletions.
4 changes: 2 additions & 2 deletions lanenet_model/lanenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ def __init__(self, phase, cfg):
self._net_flag = self._cfg.MODEL.FRONT_END

self._frontend = lanenet_front_end.LaneNetFrondEnd(
phase=phase, net_flag=self._net_flag
phase=phase, net_flag=self._net_flag, cfg=self._cfg
)
self._backend = lanenet_back_end.LaneNetBackEnd(
phase=phase
phase=phase, cfg=self._cfg
)

def inference(self, input_tensor, name, reuse=False):
Expand Down
12 changes: 5 additions & 7 deletions lanenet_model/lanenet_back_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,27 @@
"""
import tensorflow as tf

from local_utils.config_utils import parse_config_utils
from lanenet_model import lanenet_discriminative_loss
from semantic_segmentation_zoo import cnn_basenet

CFG = parse_config_utils.lanenet_cfg


class LaneNetBackEnd(cnn_basenet.CNNBaseModel):
"""
LaneNet backend branch which is mainly used for binary and instance segmentation loss calculation
"""
def __init__(self, phase):
def __init__(self, phase, cfg):
"""
init lanenet backend
:param phase: train or test
"""
super(LaneNetBackEnd, self).__init__()
self._cfg = cfg
self._phase = phase
self._is_training = self._is_net_for_training()

self._class_nums = CFG.DATASET.NUM_CLASSES
self._embedding_dims = CFG.MODEL.EMBEDDING_FEATS_DIMS
self._binary_loss_type = CFG.SOLVER.LOSS_TYPE
self._class_nums = self._cfg.DATASET.NUM_CLASSES
self._embedding_dims = self._cfg.MODEL.EMBEDDING_FEATS_DIMS
self._binary_loss_type = self._cfg.SOLVER.LOSS_TYPE

def _is_net_for_training(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion tools/test_lanenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_lanenet(image_path, weights_path):

input_tensor = tf.placeholder(dtype=tf.float32, shape=[1, 256, 512, 3], name='input_tensor')

net = lanenet.LaneNet(phase='test')
net = lanenet.LaneNet(phase='test', cfg=CFG)
binary_seg_ret, instance_seg_ret = net.inference(input_tensor=input_tensor, name='LaneNet')

postprocessor = lanenet_postprocess.LaneNetPostProcessor()
Expand Down

0 comments on commit 6505299

Please sign in to comment.