Skip to content

Commit

Permalink
merged conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
MSiam committed Dec 29, 2017
2 parents 6444b83 + fdeaf16 commit 38ab3c1
Show file tree
Hide file tree
Showing 28 changed files with 1,448 additions and 83 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ data/full_cityscapes/
pretrained_weights/*.npy
pretrained_weights/*.npz
pretrained_weights/*.t7
pretrained_weights/*.net
pretrained_weights/*.pkl
pretrained_weights/linknet/

# imgs
*.png
Expand All @@ -142,3 +145,5 @@ pretrained_weights/*.t7
!data/data_for_test_n_overfit/Y.npy
cityscapesScripts/cityscapesscripts/annotation/icons/
cityscapesScripts/cityscapesscripts/viewer/icons/
data/cityscapes_tfdata/images/
data/cityscapes_tfdata/labels/
159 changes: 159 additions & 0 deletions Variables.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
<tf.Variable 'network/conv1_x/conv1/weights:0' shape=(7, 7, 3, 64) dtype=float32_ref>
<tf.Variable 'network/conv1_x/bn1/mu:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/conv1_x/bn1/sigma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/conv1_x/bn1/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/conv1_x/bn1/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/conv2_x/conv2_1/conv_1/weights:0' shape=(3, 3, 64, 64) dtype=float32_ref>
<tf.Variable 'network/conv2_x/conv2_1/bn_1/mu:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/conv2_x/conv2_1/bn_1/sigma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/conv2_x/conv2_1/bn_1/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/conv2_x/conv2_1/bn_1/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/conv2_x/conv2_1/conv_2/weights:0' shape=(3, 3, 64, 64) dtype=float32_ref>
<tf.Variable 'network/conv2_x/conv2_1/bn_2/mu:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/conv2_x/conv2_1/bn_2/sigma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/conv2_x/conv2_1/bn_2/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/conv2_x/conv2_1/bn_2/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/conv2_x/conv2_2/conv_1/weights:0' shape=(3, 3, 64, 64) dtype=float32_ref>
<tf.Variable 'network/conv2_x/conv2_2/bn_1/mu:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/conv2_x/conv2_2/bn_1/sigma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/conv2_x/conv2_2/bn_1/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/conv2_x/conv2_2/bn_1/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/conv2_x/conv2_2/conv_2/weights:0' shape=(3, 3, 64, 64) dtype=float32_ref>
<tf.Variable 'network/conv2_x/conv2_2/bn_2/mu:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/conv2_x/conv2_2/bn_2/sigma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/conv2_x/conv2_2/bn_2/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/conv2_x/conv2_2/bn_2/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/conv3_x/conv3_1/shortcut_conv/weights:0' shape=(1, 1, 64, 128) dtype=float32_ref>
<tf.Variable 'network/conv3_x/conv3_1/conv_1/weights:0' shape=(3, 3, 64, 128) dtype=float32_ref>
<tf.Variable 'network/conv3_x/conv3_1/bn_1/mu:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/conv3_x/conv3_1/bn_1/sigma:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/conv3_x/conv3_1/bn_1/beta:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/conv3_x/conv3_1/bn_1/gamma:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/conv3_x/conv3_1/conv_2/weights:0' shape=(3, 3, 128, 128) dtype=float32_ref>
<tf.Variable 'network/conv3_x/conv3_1/bn_2/mu:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/conv3_x/conv3_1/bn_2/sigma:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/conv3_x/conv3_1/bn_2/beta:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/conv3_x/conv3_1/bn_2/gamma:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/conv3_x/conv3_2/conv_1/weights:0' shape=(3, 3, 128, 128) dtype=float32_ref>
<tf.Variable 'network/conv3_x/conv3_2/bn_1/mu:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/conv3_x/conv3_2/bn_1/sigma:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/conv3_x/conv3_2/bn_1/beta:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/conv3_x/conv3_2/bn_1/gamma:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/conv3_x/conv3_2/conv_2/weights:0' shape=(3, 3, 128, 128) dtype=float32_ref>
<tf.Variable 'network/conv3_x/conv3_2/bn_2/mu:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/conv3_x/conv3_2/bn_2/sigma:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/conv3_x/conv3_2/bn_2/beta:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/conv3_x/conv3_2/bn_2/gamma:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/conv4_x/conv4_1/shortcut_conv/weights:0' shape=(1, 1, 128, 256) dtype=float32_ref>
<tf.Variable 'network/conv4_x/conv4_1/conv_1/weights:0' shape=(3, 3, 128, 256) dtype=float32_ref>
<tf.Variable 'network/conv4_x/conv4_1/bn_1/mu:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'network/conv4_x/conv4_1/bn_1/sigma:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'network/conv4_x/conv4_1/bn_1/beta:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'network/conv4_x/conv4_1/bn_1/gamma:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'network/conv4_x/conv4_1/conv_2/weights:0' shape=(3, 3, 256, 256) dtype=float32_ref>
<tf.Variable 'network/conv4_x/conv4_1/bn_2/mu:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'network/conv4_x/conv4_1/bn_2/sigma:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'network/conv4_x/conv4_1/bn_2/beta:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'network/conv4_x/conv4_1/bn_2/gamma:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'network/conv4_x/conv4_2/conv_1/weights:0' shape=(3, 3, 256, 256) dtype=float32_ref>
<tf.Variable 'network/conv4_x/conv4_2/bn_1/mu:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'network/conv4_x/conv4_2/bn_1/sigma:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'network/conv4_x/conv4_2/bn_1/beta:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'network/conv4_x/conv4_2/bn_1/gamma:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'network/conv4_x/conv4_2/conv_2/weights:0' shape=(3, 3, 256, 256) dtype=float32_ref>
<tf.Variable 'network/conv4_x/conv4_2/bn_2/mu:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'network/conv4_x/conv4_2/bn_2/sigma:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'network/conv4_x/conv4_2/bn_2/beta:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'network/conv4_x/conv4_2/bn_2/gamma:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'network/conv5_x/conv5_1/shortcut_conv/weights:0' shape=(1, 1, 256, 512) dtype=float32_ref>
<tf.Variable 'network/conv5_x/conv5_1/conv_1/weights:0' shape=(3, 3, 256, 512) dtype=float32_ref>
<tf.Variable 'network/conv5_x/conv5_1/bn_1/mu:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'network/conv5_x/conv5_1/bn_1/sigma:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'network/conv5_x/conv5_1/bn_1/beta:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'network/conv5_x/conv5_1/bn_1/gamma:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'network/conv5_x/conv5_1/conv_2/weights:0' shape=(3, 3, 512, 512) dtype=float32_ref>
<tf.Variable 'network/conv5_x/conv5_1/bn_2/mu:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'network/conv5_x/conv5_1/bn_2/sigma:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'network/conv5_x/conv5_1/bn_2/beta:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'network/conv5_x/conv5_1/bn_2/gamma:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'network/conv5_x/conv5_2/conv_1/weights:0' shape=(3, 3, 512, 512) dtype=float32_ref>
<tf.Variable 'network/conv5_x/conv5_2/bn_1/mu:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'network/conv5_x/conv5_2/bn_1/sigma:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'network/conv5_x/conv5_2/bn_1/beta:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'network/conv5_x/conv5_2/bn_1/gamma:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'network/conv5_x/conv5_2/conv_2/weights:0' shape=(3, 3, 512, 512) dtype=float32_ref>
<tf.Variable 'network/conv5_x/conv5_2/bn_2/mu:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'network/conv5_x/conv5_2/bn_2/sigma:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'network/conv5_x/conv5_2/bn_2/beta:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'network/conv5_x/conv5_2/bn_2/gamma:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_4/conv_1/conv2d/kernel:0' shape=(1, 1, 512, 128) dtype=float32_ref>
<tf.Variable 'network/decoder_block_4/conv_1/batch_normalization/gamma:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_4/conv_1/batch_normalization/beta:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_4/conv_1/batch_normalization/moving_mean:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_4/conv_1/batch_normalization/moving_variance:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_4/deconv/deconv/weights:0' shape=(3, 3, 128, 128) dtype=float32_ref>
<tf.Variable 'network/decoder_block_4/deconv/batch_normalization/gamma:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_4/deconv/batch_normalization/beta:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_4/deconv/batch_normalization/moving_mean:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_4/deconv/batch_normalization/moving_variance:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_4/conv_2/conv2d/kernel:0' shape=(1, 1, 128, 256) dtype=float32_ref>
<tf.Variable 'network/decoder_block_4/conv_2/batch_normalization/gamma:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_4/conv_2/batch_normalization/beta:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_4/conv_2/batch_normalization/moving_mean:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_4/conv_2/batch_normalization/moving_variance:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_3/conv_1/conv2d/kernel:0' shape=(1, 1, 256, 64) dtype=float32_ref>
<tf.Variable 'network/decoder_block_3/conv_1/batch_normalization/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_3/conv_1/batch_normalization/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_3/conv_1/batch_normalization/moving_mean:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_3/conv_1/batch_normalization/moving_variance:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_3/deconv/deconv/weights:0' shape=(3, 3, 64, 64) dtype=float32_ref>
<tf.Variable 'network/decoder_block_3/deconv/batch_normalization/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_3/deconv/batch_normalization/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_3/deconv/batch_normalization/moving_mean:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_3/deconv/batch_normalization/moving_variance:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_3/conv_2/conv2d/kernel:0' shape=(1, 1, 64, 128) dtype=float32_ref>
<tf.Variable 'network/decoder_block_3/conv_2/batch_normalization/gamma:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_3/conv_2/batch_normalization/beta:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_3/conv_2/batch_normalization/moving_mean:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_3/conv_2/batch_normalization/moving_variance:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_2/conv_1/conv2d/kernel:0' shape=(1, 1, 128, 32) dtype=float32_ref>
<tf.Variable 'network/decoder_block_2/conv_1/batch_normalization/gamma:0' shape=(32,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_2/conv_1/batch_normalization/beta:0' shape=(32,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_2/conv_1/batch_normalization/moving_mean:0' shape=(32,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_2/conv_1/batch_normalization/moving_variance:0' shape=(32,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_2/deconv/deconv/weights:0' shape=(3, 3, 32, 32) dtype=float32_ref>
<tf.Variable 'network/decoder_block_2/deconv/batch_normalization/gamma:0' shape=(32,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_2/deconv/batch_normalization/beta:0' shape=(32,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_2/deconv/batch_normalization/moving_mean:0' shape=(32,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_2/deconv/batch_normalization/moving_variance:0' shape=(32,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_2/conv_2/conv2d/kernel:0' shape=(1, 1, 32, 64) dtype=float32_ref>
<tf.Variable 'network/decoder_block_2/conv_2/batch_normalization/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_2/conv_2/batch_normalization/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_2/conv_2/batch_normalization/moving_mean:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_2/conv_2/batch_normalization/moving_variance:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_1/conv_1/conv2d/kernel:0' shape=(1, 1, 64, 16) dtype=float32_ref>
<tf.Variable 'network/decoder_block_1/conv_1/batch_normalization/gamma:0' shape=(16,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_1/conv_1/batch_normalization/beta:0' shape=(16,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_1/conv_1/batch_normalization/moving_mean:0' shape=(16,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_1/conv_1/batch_normalization/moving_variance:0' shape=(16,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_1/deconv/deconv/weights:0' shape=(3, 3, 16, 16) dtype=float32_ref>
<tf.Variable 'network/decoder_block_1/deconv/batch_normalization/gamma:0' shape=(16,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_1/deconv/batch_normalization/beta:0' shape=(16,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_1/deconv/batch_normalization/moving_mean:0' shape=(16,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_1/deconv/batch_normalization/moving_variance:0' shape=(16,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_1/conv_2/conv2d/kernel:0' shape=(1, 1, 16, 64) dtype=float32_ref>
<tf.Variable 'network/decoder_block_1/conv_2/batch_normalization/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_1/conv_2/batch_normalization/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_1/conv_2/batch_normalization/moving_mean:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/decoder_block_1/conv_2/batch_normalization/moving_variance:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'network/output_block/deconv_out_1/weights:0' shape=(3, 3, 32, 64) dtype=float32_ref>
<tf.Variable 'network/output_block/batch_normalization/gamma:0' shape=(32,) dtype=float32_ref>
<tf.Variable 'network/output_block/batch_normalization/beta:0' shape=(32,) dtype=float32_ref>
<tf.Variable 'network/output_block/batch_normalization/moving_mean:0' shape=(32,) dtype=float32_ref>
<tf.Variable 'network/output_block/batch_normalization/moving_variance:0' shape=(32,) dtype=float32_ref>
<tf.Variable 'network/output_block/conv2d/kernel:0' shape=(3, 3, 32, 32) dtype=float32_ref>
<tf.Variable 'network/output_block/batch_normalization_1/gamma:0' shape=(32,) dtype=float32_ref>
<tf.Variable 'network/output_block/batch_normalization_1/beta:0' shape=(32,) dtype=float32_ref>
<tf.Variable 'network/output_block/batch_normalization_1/moving_mean:0' shape=(32,) dtype=float32_ref>
<tf.Variable 'network/output_block/batch_normalization_1/moving_variance:0' shape=(32,) dtype=float32_ref>
<tf.Variable 'network/output_block/deconv_out_2/weights:0' shape=(2, 2, 20, 32) dtype=float32_ref>
58 changes: 49 additions & 9 deletions agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

import os
import pdb
import pickle
from utils.misc import calculate_flops

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
Expand Down Expand Up @@ -38,12 +40,13 @@ def __init__(self, args):

@timeit
def build_model(self):
print('Building Train Network')
with tf.variable_scope('network') as scope:
self.train_model = self.model(self.args, phase=0)
self.train_model.build()

if self.mode == 'train': # validation phase
if self.mode == 'train' or self.mode == 'overfit': # validation phase
print('Building Train Network')
with tf.variable_scope('network') as scope:
self.train_model = self.model(self.args, phase=0)
self.train_model.build()

print('Building Test Network')
with tf.variable_scope('network') as scope:
scope.reuse_variables()
Expand All @@ -52,7 +55,7 @@ def build_model(self):
else: # inference phase
print('Building Test Network')
with tf.variable_scope('network') as scope:
scope.reuse_variables()
self.train_model = None
self.test_model = self.model(self.args, phase=2)
self.test_model.build()

Expand All @@ -75,6 +78,7 @@ def run(self):
# Create Model class and build it
with self.sess.as_default():
self.build_model()

# Create the operator
self.operator = self.operator(self.args, self.sess, self.train_model, self.test_model)

Expand All @@ -89,22 +93,51 @@ def run(self):
self.overfit()
elif self.mode == 'inference':
self.inference()
else:
elif self.mode == 'inference_pkl':
self.load_pretrained_weights(self.sess, 'pretrained_weights/linknet_weights.pkl')
self.test(pkl=True)
elif self.mode == 'debug':
self.debug()
elif self.mode == 'test':
self.test()
else:
print("This mode {{{}}} is not found in our framework".format(self.mode))
exit(-1)

self.sess.close()
print("\nAgent is exited...\n")

def load_pretrained_weights(self, sess, pretrained_path):
print('############### START Loading from PKL ##################')
with open(pretrained_path, 'rb') as ff:
pretrained_weights = pickle.load(ff, encoding='latin1')

print("Loading pretrained weights of resnet18")
# all_vars = tf.trainable_variables()
# all_vars += tf.get_collection('mu_sigma_bn')
all_vars = tf.all_variables()
for v in all_vars:
if v.op.name in pretrained_weights.keys():
if str(v.shape) != str(pretrained_weights[v.op.name].shape):
print(v.shape)
print(pretrained_weights[v.op.name].shape)
print("Oh goooddd!!!")
exit(0)
assign_op = v.assign(pretrained_weights[v.op.name])
sess.run(assign_op)
print(v.op.name + " - loaded successfully, size ", pretrained_weights[v.op.name].shape)
print("All pretrained weights of resnet18 is loaded")

def train(self):
try:
self.operator.train()
self.operator.finalize()
except KeyboardInterrupt:
self.operator.finalize()

def test(self):
def test(self, pkl=False):
try:
self.operator.test()
self.operator.test(pkl)
except KeyboardInterrupt:
pass

Expand All @@ -120,3 +153,10 @@ def inference(self):
self.operator.test_inference()
except KeyboardInterrupt:
pass

def debug(self):
self.load_pretrained_weights(self.sess, 'pretrained_weights/linknet_weights.pkl')
try:
self.operator.debug_layers()
except KeyboardInterrupt:
pass
Loading

0 comments on commit 38ab3c1

Please sign in to comment.