# test_TFCluster.py
import unittest
import test
import time
from tensorflowonspark import TFCluster, TFNode


class TFClusterTest(test.SparkTest):
  @classmethod
  def setUpClass(cls):
    super(TFClusterTest, cls).setUpClass()

  @classmethod
  def tearDownClass(cls):
    super(TFClusterTest, cls).tearDownClass()
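
  # Note: test.SparkTest is the repo's shared Spark test harness; it is assumed
  # to provide `cls.sc` (a SparkContext) and `cls.num_workers` used by the tests below.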

  def test_basic_tf(self):
    """Single-node TF graph (w/ args) running independently on multiple executors."""
    def _map_fun(args, ctx):
      import tensorflow as tf
      x = tf.constant(args['x'])
      y = tf.constant(args['y'])
      total = tf.math.add(x, y)   # renamed from `sum` to avoid shadowing the built-in
      assert total.numpy() == 3

    args = {'x': 1, 'y': 2}
    cluster = TFCluster.run(self.sc, _map_fun, tf_args=args, num_executors=self.num_workers, num_ps=0)
    cluster.shutdown()
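
  # Usage sketch (comment-only, not exercised by this suite): the same
  # TFCluster.run API also drives training over an RDD via InputMode.SPARK.
  # `data_rdd` and `map_fun` are hypothetical names, and `train()` is assumed
  # to feed RDD partitions to the executors for the requested number of epochs.
  #
  #   cluster = TFCluster.run(sc, map_fun, tf_args, num_executors,
  #                           num_ps=0, input_mode=TFCluster.InputMode.SPARK)
  #   cluster.train(data_rdd, num_epochs=1)
  #   cluster.shutdown()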

  def test_inputmode_spark(self):
    """Distributed TF cluster w/ InputMode.SPARK"""
    def _map_fun(args, ctx):
      import tensorflow as tf
      tf_feed = TFNode.DataFeed(ctx.mgr, False)
      while not tf_feed.should_stop():
        batch = tf_feed.next_batch(batch_size=10)
        print("batch: {}".format(batch))
        squares = tf.math.square(batch)
        print("squares: {}".format(squares))
        tf_feed.batch_results(squares.numpy())

    input_data = [[x] for x in range(1000)]   # input as tensors of shape [1]; renamed from `input` to avoid shadowing the built-in
    rdd = self.sc.parallelize(input_data, 10)
    cluster = TFCluster.run(self.sc, _map_fun, tf_args={}, num_executors=self.num_workers, num_ps=0, input_mode=TFCluster.InputMode.SPARK)
    rdd_out = cluster.inference(rdd)
    rdd_sum = rdd_out.sum()
    self.assertEqual(rdd_sum, sum([x * x for x in range(1000)]))
    cluster.shutdown()
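
  # The DataFeed consumer loop above, restated as a standalone sketch
  # (`process` is a hypothetical per-batch model step, not part of this suite):
  #
  #   tf_feed = TFNode.DataFeed(ctx.mgr, False)     # False => inference mode
  #   while not tf_feed.should_stop():
  #     batch = tf_feed.next_batch(batch_size=10)   # pull items fed from Spark partitions
  #     if len(batch) > 0:
  #       tf_feed.batch_results(process(batch))     # push results to the output queue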

  def test_inputmode_spark_exception(self):
    """Distributed TF cluster w/ InputMode.SPARK and exception during feeding"""
    def _map_fun(args, ctx):
      import tensorflow as tf
      tf_feed = TFNode.DataFeed(ctx.mgr, False)
      while not tf_feed.should_stop():
        batch = tf_feed.next_batch(10)
        if len(batch) > 0:
          squares = tf.math.square(batch)
          tf_feed.batch_results(squares.numpy())
          raise Exception("FAKE exception during feeding")

    input_data = [[x] for x in range(1000)]   # input as tensors of shape [1]
    rdd = self.sc.parallelize(input_data, 10)
    with self.assertRaises(Exception):
      cluster = TFCluster.run(self.sc, _map_fun, tf_args={}, num_executors=self.num_workers, num_ps=0, input_mode=TFCluster.InputMode.SPARK)
      cluster.inference(rdd, feed_timeout=1).count()   # short feed_timeout so the failed feed surfaces quickly
      cluster.shutdown()

  def test_inputmode_spark_late_exception(self):
    """Distributed TF cluster w/ InputMode.SPARK and exception after feeding"""
    def _map_fun(args, ctx):
      import tensorflow as tf
      tf_feed = TFNode.DataFeed(ctx.mgr, False)
      while not tf_feed.should_stop():
        batch = tf_feed.next_batch(10)
        if len(batch) > 0:
          squares = tf.math.square(batch)
          tf_feed.batch_results(squares.numpy())
      # simulate post-feed actions that raise an exception
      time.sleep(2)
      raise Exception("FAKE exception after feeding")

    input_data = [[x] for x in range(1000)]   # input as tensors of shape [1]
    rdd = self.sc.parallelize(input_data, 10)
    with self.assertRaises(Exception):
      cluster = TFCluster.run(self.sc, _map_fun, tf_args={}, num_executors=self.num_workers, num_ps=0, input_mode=TFCluster.InputMode.SPARK)
      cluster.inference(rdd).count()
      cluster.shutdown(grace_secs=5)   # grace_secs must exceed the time needed for post-feed actions

  def test_port_released(self):
    """Test that the temporary socket/port is released prior to invoking the user map_fun."""
    def _map_fun(args, ctx):
      assert ctx.tmp_socket is None

    cluster = TFCluster.run(self.sc, _map_fun, tf_args={}, num_executors=self.num_workers, num_ps=0, input_mode=TFCluster.InputMode.TENSORFLOW, master_node='chief')
    cluster.shutdown()

  def test_port_unreleased(self):
    """Test that the temporary socket/port remains reserved prior to invoking the user map_fun."""
    def _map_fun(args, ctx):
      import socket
      assert ctx.tmp_socket is not None
      reserved_port = ctx.tmp_socket.getsockname()[1]
      # binding a second socket to the reserved port should fail
      my_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
      try:
        my_sock.bind(('0.0.0.0', reserved_port))
        assert False, "should never hit this assert statement"
      except socket.error as e:
        print(e)   # expected: the port is still held by ctx.tmp_socket
      finally:
        my_sock.close()
      ctx.release_port()
      assert ctx.tmp_socket is None

    cluster = TFCluster.run(self.sc, _map_fun, tf_args={}, num_executors=self.num_workers, num_ps=0, input_mode=TFCluster.InputMode.TENSORFLOW, master_node='chief', release_port=False)
    cluster.shutdown()
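
  # Sketch of consuming the reserved port inside a map_fun (comment-only;
  # starting a TF server from the port is an assumption, not part of this test):
  #
  #   port = ctx.tmp_socket.getsockname()[1]   # port held open by TFCluster
  #   ctx.release_port()                       # free it just before TF binds
  #   # ... build TF_CONFIG / start the distributed worker on `port` ...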


if __name__ == '__main__':
  unittest.main()