Upgrade to TF 1.0
nealwu committed Apr 3, 2017
1 parent b8c29cd commit 40c3b10
Showing 5 changed files with 20 additions and 18 deletions.
10 changes: 6 additions & 4 deletions domain_adaptation/domain_separation/dsn.py
@@ -282,15 +282,17 @@ def concat_operation(shared_repr, private_repr):
 
   # Add summaries
   source_reconstructions = tf.concat(
-      map(normalize_images, [
+      axis=2,
+      values=map(normalize_images, [
           source_data, source_recons, source_shared_recons,
           source_private_recons
-      ]), 2)
+      ]))
   target_reconstructions = tf.concat(
-      map(normalize_images, [
+      axis=2,
+      values=map(normalize_images, [
           target_data, target_recons, target_shared_recons,
           target_private_recons
-      ]), 2)
+      ]))
   tf.summary.image(
       'Source Images:Recons:RGB',
       source_reconstructions[:, :, :, :3],
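The hunk above switches every tf.concat call to explicit keyword arguments, which reads unambiguously under the TF 1.0 signature tf.concat(values, axis, name='concat'). A minimal sketch of the adopted form; the tensors here are illustrative placeholders, not the model's reconstruction images:

import tensorflow as tf

a = tf.zeros([4, 28, 28, 3])
b = tf.ones([4, 28, 28, 3])

# Keywords make the argument roles explicit regardless of positional order.
side_by_side = tf.concat(axis=2, values=[a, b])  # shape (4, 28, 56, 3)

with tf.Session() as sess:
    print(sess.run(tf.shape(side_by_side)))  # [ 4 28 56  3]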
2 changes: 1 addition & 1 deletion domain_adaptation/domain_separation/dsn_test.py
@@ -26,7 +26,7 @@ def testBasicDomainSeparationStartPoint(self):
     with self.test_session() as sess:
       # Test for when global_step < domain_separation_startpoint
       step = tf.contrib.slim.get_or_create_global_step()
-      sess.run(tf.initialize_all_variables())  # global_step = 0
+      sess.run(tf.global_variables_initializer())  # global_step = 0
       params = {'domain_separation_startpoint': 2}
       weight = dsn.dsn_loss_coefficient(params)
       weight_np = sess.run(weight)
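tf.initialize_all_variables was deprecated in TF 1.0 in favor of tf.global_variables_initializer. A tiny self-contained sketch of the renamed op; the variable below is a made-up stand-in for the test's global step:

import tensorflow as tf

step = tf.Variable(0, trainable=False, name='global_step')

with tf.Session() as sess:
    # TF 1.0 name for the op that initializes all global variables.
    sess.run(tf.global_variables_initializer())
    print(sess.run(step))  # 0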
14 changes: 7 additions & 7 deletions domain_adaptation/domain_separation/losses.py
@@ -100,7 +100,7 @@ def mmd_loss(source_samples, target_samples, weight, scope=None):
   tag = 'MMD Loss'
   if scope:
     tag = scope + tag
-  tf.contrib.deprecated.scalar_summary(tag, loss_value)
+  tf.summary.scalar(tag, loss_value)
   tf.losses.add_loss(loss_value)
 
   return loss_value
@@ -135,7 +135,7 @@ def correlation_loss(source_samples, target_samples, weight, scope=None):
   tag = 'Correlation Loss'
   if scope:
     tag = scope + tag
-  tf.contrib.deprecated.scalar_summary(tag, corr_loss)
+  tf.summary.scalar(tag, corr_loss)
   tf.losses.add_loss(corr_loss)
 
   return corr_loss
@@ -155,11 +155,11 @@ def dann_loss(source_samples, target_samples, weight, scope=None):
   """
   with tf.variable_scope('dann'):
     batch_size = tf.shape(source_samples)[0]
-    samples = tf.concat([source_samples, target_samples], 0)
+    samples = tf.concat(axis=0, values=[source_samples, target_samples])
     samples = slim.flatten(samples)
 
     domain_selection_mask = tf.concat(
-        [tf.zeros((batch_size, 1)), tf.ones((batch_size, 1))], 0)
+        axis=0, values=[tf.zeros((batch_size, 1)), tf.ones((batch_size, 1))])
 
     # Perform the gradient reversal and be careful with the shape.
     grl = grl_ops.gradient_reversal(samples)
@@ -184,9 +184,9 @@ def dann_loss(source_samples, target_samples, weight, scope=None):
       tag_loss = scope + tag_loss
       tag_accuracy = scope + tag_accuracy
 
-    tf.contrib.deprecated.scalar_summary(
+    tf.summary.scalar(
         tag_loss, domain_loss, name='domain_loss_summary')
-    tf.contrib.deprecated.scalar_summary(
+    tf.summary.scalar(
         tag_accuracy, domain_accuracy, name='domain_accuracy_summary')
 
   return domain_loss
@@ -216,7 +216,7 @@ def difference_loss(private_samples, shared_samples, weight=1.0, name=''):
   cost = tf.reduce_mean(tf.square(correlation_matrix)) * weight
   cost = tf.where(cost > 0, cost, 0, name='value')
 
-  tf.contrib.deprecated.scalar_summary('losses/Difference Loss {}'.format(name),
+  tf.summary.scalar('losses/Difference Loss {}'.format(name),
                                        cost)
   assert_op = tf.Assert(tf.is_finite(cost), [cost])
   with tf.control_dependencies([assert_op]):
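All of the losses above move from the deprecated tf.contrib.deprecated.scalar_summary to the TF 1.0 tf.summary.scalar API. A minimal sketch with a dummy loss tensor; note that tf.summary.scalar treats the tag as an op name, so tags containing spaces (like 'MMD Loss' above) get sanitized with a warning:

import tensorflow as tf

loss_value = tf.constant(0.25, name='dummy_loss')

# TF 1.0 summary API: tf.summary.scalar(name, tensor). The summary op is
# registered in the default graph's SUMMARIES collection.
tf.summary.scalar('losses/dummy_loss', loss_value)
merged = tf.summary.merge_all()

with tf.Session() as sess:
    summary_proto = sess.run(merged)  # serialized Summary protobuf bytes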
2 changes: 1 addition & 1 deletion domain_adaptation/domain_separation/models_test.py
@@ -115,7 +115,7 @@ def _testDecoder(self,
           width=width,
           channels=channels,
           batch_norm_params=batch_norm_params)
-      sess.run(tf.initialize_all_variables())
+      sess.run(tf.global_variables_initializer())
       output_np = sess.run(output)
       self.assertEqual(output_np.shape, (32, height, width, channels))
       self.assertTrue(np.any(output_np))
10 changes: 5 additions & 5 deletions domain_adaptation/domain_separation/utils.py
@@ -75,15 +75,15 @@ def reshape_feature_maps(features_tensor):
                            num_filters)
   num_filters_sqrt = int(num_filters_sqrt)
   conv_summary = tf.unstack(features_tensor, axis=3)
-  conv_one_row = tf.concat(conv_summary[0:num_filters_sqrt], 2)
+  conv_one_row = tf.concat(axis=2, values=conv_summary[0:num_filters_sqrt])
   ind = 1
   conv_final = conv_one_row
   for ind in range(1, num_filters_sqrt):
-    conv_one_row = tf.concat(conv_summary[
-        ind * num_filters_sqrt + 0:ind * num_filters_sqrt + num_filters_sqrt],
-        2)
+    conv_one_row = tf.concat(axis=2,
+                             values=conv_summary[
+                                 ind * num_filters_sqrt + 0:ind * num_filters_sqrt + num_filters_sqrt])
     conv_final = tf.concat(
-        [tf.squeeze(conv_final), tf.squeeze(conv_one_row)], 1)
+        axis=1, values=[tf.squeeze(conv_final), tf.squeeze(conv_one_row)])
   conv_final = tf.expand_dims(conv_final, -1)
   return conv_final

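reshape_feature_maps tiles the per-filter activation maps into a square grid for image summaries. A toy version of that pattern using the keyword-style tf.concat, with invented shapes (16 filters tiled into a 4x4 grid):

import tensorflow as tf

features = tf.zeros([8, 6, 6, 16])   # batch, height, width, num_filters
maps = tf.unstack(features, axis=3)  # 16 tensors of shape (8, 6, 6)

side = 4  # int(sqrt(16))
rows = []
for i in range(side):
    # Place each group of `side` maps next to each other along the width.
    rows.append(tf.concat(axis=2, values=maps[i * side:(i + 1) * side]))
# Stack the rows along the height to form the final grid.
grid = tf.concat(axis=1, values=rows)  # shape (8, 24, 24)
grid = tf.expand_dims(grid, -1)        # add a channels dim for tf.summary.image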
