@@ -78,13 +78,14 @@ def _build_loss(self):
78
78
with tf .variable_scope ("refiner" ):
79
79
self .realism_loss = tf .reduce_sum (
80
80
SE_loss (self .D_R_x_logits , real_label ), [1 , 2 ], name = "realism_loss" )
81
- self .regularization_loss = 0
82
- # self.reg_scale * tf.reduce_sum(
83
- # self.R_x - self.normalized_x, [1, 2, 3],
84
- # name="regularization_loss")
81
+ #self.regularization_loss = 0
82
+ self .regularization_loss = \
83
+ self .reg_scale * tf .reduce_sum (
84
+ tf .abs (self .R_x - self .normalized_x ), [1 , 2 , 3 ],
85
+ name = "regularization_loss" )
85
86
86
87
self .refiner_loss = tf .reduce_mean (
87
- self .realism_loss , # + self.regularization_loss,
88
+ self .realism_loss + self .regularization_loss ,
88
89
name = "refiner_loss" )
89
90
90
91
if self .debug :
@@ -162,22 +163,24 @@ def test_refiner(sess, inputs, summary_writer=None, with_output=False):
162
163
self .refiner_summary , summary_writer ,
163
164
output_op = self .R_x if with_output else None )
164
165
165
- def train_discrim (sess , inputs , summary_writer = None ):
166
+ def train_discrim (sess , inputs , summary_writer = None , with_output = False ):
166
167
fetch = {
167
168
'loss' : self .discrim_loss ,
168
169
'optim' : self .discrim_optim ,
169
170
'step' : self .discrim_step ,
170
171
}
171
172
return run (sess , inputs , fetch , self .x ,
172
- self .discrim_summary , summary_writer )
173
+ self .discrim_summary , summary_writer ,
174
+ output_op = self .D_R_x if with_output else None )
173
175
174
- def test_discrim (sess , inputs , summary_writer = None ):
176
+ def test_discrim (sess , inputs , summary_writer = None , with_output = False ):
175
177
fetch = {
176
178
'loss' : self .discrim_loss ,
177
179
'step' : self .discrim_step ,
178
180
}
179
181
return run (sess , inputs , fetch , self .x ,
180
- self .discrim_summary , summary_writer = summary_writer )
182
+ self .discrim_summary , summary_writer ,
183
+ output_op = self .D_R_x if with_output else None )
181
184
182
185
self .train_refiner = train_refiner
183
186
self .test_refiner = test_refiner
0 commit comments