Skip to content

Commit ee1d46b

Browse files
committedJan 3, 2017
now need to add debug for discriminator and add history
1 parent 9eb6bd2 commit ee1d46b

File tree

4 files changed

+20
-18
lines changed

4 files changed

+20
-18
lines changed
 

‎config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def add_argument_group(name):
3636
train_arg.add_argument('--is_train', type=str2bool, default=True, help='')
3737
train_arg.add_argument('--optimizer', type=str, default='rmsprop', help='')
3838
train_arg.add_argument('--max_step', type=int, default=200, help='')
39-
train_arg.add_argument('--reg_scale', type=float, default=0.01, help='')
39+
train_arg.add_argument('--reg_scale', type=float, default=1, help='')
4040
train_arg.add_argument('--initial_K_d', type=int, default=200, help='')
4141
train_arg.add_argument('--initial_K_g', type=int, default=1000, help='')
4242
train_arg.add_argument('--K_d', type=int, default=1, help='')

‎layers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def conv2d(inputs, num_outputs, kernel_size, stride,
4343
layer_dict={}, activation_fn=None,
4444
weights_initializer=tf.random_normal_initializer(0, 0.001),
4545
scope=None, name="", reuse=False, **kargv):
46-
print tf.random_normal_initializer(0, 0.001)
46+
print weights_initializer
4747
if True:
4848
outputs = slim.conv2d(
4949
inputs, num_outputs, kernel_size,

‎model.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,14 @@ def _build_loss(self):
7878
with tf.variable_scope("refiner"):
7979
self.realism_loss = tf.reduce_sum(
8080
SE_loss(self.D_R_x_logits, real_label), [1, 2], name="realism_loss")
81-
self.regularization_loss = 0
82-
# self.reg_scale * tf.reduce_sum(
83-
# self.R_x - self.normalized_x, [1, 2, 3],
84-
# name="regularization_loss")
81+
#self.regularization_loss = 0
82+
self.regularization_loss = \
83+
self.reg_scale * tf.reduce_sum(
84+
tf.abs(self.R_x - self.normalized_x), [1, 2, 3],
85+
name="regularization_loss")
8586

8687
self.refiner_loss = tf.reduce_mean(
87-
self.realism_loss, #+ self.regularization_loss,
88+
self.realism_loss + self.regularization_loss,
8889
name="refiner_loss")
8990

9091
if self.debug:
@@ -162,22 +163,24 @@ def test_refiner(sess, inputs, summary_writer=None, with_output=False):
162163
self.refiner_summary, summary_writer,
163164
output_op=self.R_x if with_output else None)
164165

165-
def train_discrim(sess, inputs, summary_writer=None):
166+
def train_discrim(sess, inputs, summary_writer=None, with_output=False):
166167
fetch = {
167168
'loss': self.discrim_loss,
168169
'optim': self.discrim_optim,
169170
'step': self.discrim_step,
170171
}
171172
return run(sess, inputs, fetch, self.x,
172-
self.discrim_summary, summary_writer)
173+
self.discrim_summary, summary_writer,
174+
output_op=self.D_R_x if with_output else None)
173175

174-
def test_discrim(sess, inputs, summary_writer=None):
176+
def test_discrim(sess, inputs, summary_writer=None, with_output=False):
175177
fetch = {
176178
'loss': self.discrim_loss,
177179
'step': self.discrim_step,
178180
}
179181
return run(sess, inputs, fetch, self.x,
180-
self.discrim_summary, summary_writer=summary_writer)
182+
self.discrim_summary, summary_writer,
183+
output_op=self.D_R_x if with_output else None)
181184

182185
self.train_refiner = train_refiner
183186
self.test_refiner = test_refiner

‎trainer.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -57,20 +57,19 @@ def train(self):
5757
print("[*] Training starts...")
5858
summary_writer = None
5959

60+
for k in trange(self.initial_K_d, desc="Train discrim"):
61+
res = self.model.train_discrim(sess, self.data_loader.next(),
62+
summary_writer, with_output=False)
63+
summary_writer = self._get_summary_writer(res)
64+
6065
for k in trange(self.initial_K_g, desc="Train refiner"):
6166
data = self.data_loader.next()
6267
res = self.model.train_refiner(sess, data,
6368
summary_writer, with_output=True)
6469
summary_writer = self._get_summary_writer(res)
65-
import ipdb; ipdb.set_trace()
70+
# import ipdb; ipdb.set_trace()
6671
# self.model.R_x.eval({self.model.x: data)}, session=sess)
6772
# self.model.layer_dict['refiner/resnet/resnet_5'].eval({self.model.x: data},session=sess).max()
68-
x = 123
69-
70-
for k in trange(self.initial_K_d, desc="Train discrim"):
71-
res = self.model.train_discrim(sess, self.data_loader.next(),
72-
summary_writer, with_output=True)
73-
summary_writer = self._get_summary_writer(res)
7473

7574
for step in trange(self.max_step, desc="Train both"):
7675
for k in xrange(self.K_g):

0 commit comments

Comments
 (0)
Please sign in to comment.