Skip to content

Commit

Permalink
Fix bug in tensorflow tests. (ray-project#218)
Browse files Browse the repository at this point in the history
* Fix bug in tensorflow tests.

* Address comment.
  • Loading branch information
robertnishihara authored and pcmoritz committed Jan 20, 2017
1 parent 9bb8162 commit 7151ed5
Showing 1 changed file with 26 additions and 43 deletions.
69 changes: 26 additions & 43 deletions test/tensorflow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def testVariableNameCollision(self):

ray.env.net1 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
ray.env.net2 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)

net_vars1, init1, sess1 = ray.env.net1
net_vars2, init2, sess2 = ray.env.net2

Expand All @@ -108,7 +108,7 @@ def testNetworksIndependent(self):

ray.env.net1 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
ray.env.net2 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)

net_vars1, init1, sess1 = ray.env.net1
net_vars2, init2, sess2 = ray.env.net2

Expand All @@ -117,41 +117,32 @@ def testNetworksIndependent(self):
sess2.run(init2)

@ray.remote
def get_vars1():
return ray.env.net1[0].get_weights()

@ray.remote
def get_vars2():
return ray.env.net2[0].get_weights()

@ray.remote
def set_vars1(weights):
ray.env.net1[0].set_weights(weights)

@ray.remote
def set_vars2(weights):
ray.env.net2[0].set_weights(weights)

# Get the weights.
def set_and_get_weights(weights1, weights2):
ray.env.net1[0].set_weights(weights1)
ray.env.net2[0].set_weights(weights2)
return ray.env.net1[0].get_weights(), ray.env.net2[0].get_weights()

# Make sure the two networks have different weights. TODO(rkn): Note that
# equality comparisons of numpy arrays normally does not work. This only
# works because at the moment they have size 1.
weights1 = net_vars1.get_weights()
weights2 = net_vars2.get_weights()
self.assertNotEqual(weights1, weights2)

# Swap the weights.
set_vars2.remote(weights1)
set_vars1.remote(weights2)

# Get the new weights.
new_weights1 = ray.get(get_vars1.remote())
new_weights2 = ray.get(get_vars2.remote())
self.assertNotEqual(new_weights1, new_weights2)
# Set the weights and get the weights, and make sure they are unchanged.
new_weights1, new_weights2 = ray.get(set_and_get_weights.remote(weights1, weights2))
self.assertEqual(weights1, new_weights1)
self.assertEqual(weights2, new_weights2)

# Check that the weights were swapped.
self.assertEqual(weights1, new_weights2)
self.assertEqual(weights2, new_weights1)
# Swap the weights.
new_weights2, new_weights1 = ray.get(set_and_get_weights.remote(weights2, weights1))
self.assertEqual(weights1, new_weights1)
self.assertEqual(weights2, new_weights2)

ray.worker.cleanup()

# This test creates an additional network on the driver so that the tensorflow
# variables on the driver and the worker differ.
def testNetworkDriverWorkerIndependent(self):
ray.init(num_workers=1)

Expand All @@ -167,23 +158,15 @@ def testNetworkDriverWorkerIndependent(self):
net_vars2, init2, sess2 = ray.env.net
sess2.run(init2)

# Get the weights.
weights1 = net_vars1.get_weights()
weights2 = net_vars2.get_weights()
self.assertNotEqual(weights1, weights2)

# Swap the weights.
net_vars1.set_weights(weights2)
net_vars2.set_weights(weights1)

# Get the new weights.
new_weights1 = net_vars1.get_weights()
new_weights2 = net_vars2.get_weights()
self.assertNotEqual(new_weights1, new_weights2)
@ray.remote
def set_and_get_weights(weights):
ray.env.net[0].set_weights(weights)
return ray.env.net[0].get_weights()

# Check that the weights were swapped.
self.assertEqual(weights1, new_weights2)
self.assertEqual(weights2, new_weights1)
new_weights2 = ray.get(set_and_get_weights.remote(net_vars2.get_weights()))
self.assertEqual(weights2, new_weights2)

ray.worker.cleanup()

Expand Down

0 comments on commit 7151ed5

Please sign in to comment.