-
Notifications
You must be signed in to change notification settings - Fork 944
/
Copy pathtest_reservation.py
132 lines (106 loc) · 4.23 KB
/
test_reservation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import threading
import time
import unittest
from tensorflowonspark import util
from tensorflowonspark.reservation import Reservations, Server, Client
from unittest import mock
class ReservationTest(unittest.TestCase):
    """Tests for ``Reservations``, ``Server`` and ``Client`` from ``tensorflowonspark.reservation``."""

    # Upper bound on how long to wait for a server to report ``done`` after a STOP request.
    _STOP_TIMEOUT_SECS = 5.0
    # Upper bound on how long to wait for each client thread in the multi-client
    # test, so a stuck client fails the test instead of hanging the whole run.
    _JOIN_TIMEOUT_SECS = 60.0

    @staticmethod
    def _wait_for_done(server, timeout=_STOP_TIMEOUT_SECS, interval=0.05):
        """Poll ``server.done`` until it is truthy or ``timeout`` seconds elapse.

        ``done`` may be set asynchronously after ``Client.request_stop()``
        returns, so polling is used instead of a fixed sleep: it returns as
        soon as the flag flips and still tolerates slow machines.

        Args:
            server: object exposing a ``done`` attribute (a ``Server``).
            timeout: maximum number of seconds to wait.
            interval: seconds to sleep between checks.

        Returns:
            The final value of ``server.done``.
        """
        deadline = time.monotonic() + timeout
        while not server.done and time.monotonic() < deadline:
            time.sleep(interval)
        return server.done

    def test_reservation_class(self):
        """Test core reservation class, expecting 2 reservations"""
        r = Reservations(2)
        self.assertFalse(r.done())

        # add first reservation
        r.add({'node': 1})
        self.assertFalse(r.done())
        self.assertEqual(r.remaining(), 1)

        # add second reservation
        r.add({'node': 2})
        self.assertTrue(r.done())
        self.assertEqual(r.remaining(), 0)

        # get final list
        reservations = r.get()
        self.assertEqual(len(reservations), 2)

    def test_reservation_server(self):
        """Test reservation server, expecting 1 reservation"""
        s = Server(1)
        addr = s.start()
        # Make sure the server is told to stop even if an assertion below fails.
        self.addCleanup(s.stop)

        # add first reservation
        c = Client(addr)
        self.addCleanup(c.close)
        resp = c.register({'node': 1})
        self.assertEqual(resp, 'OK')

        # get list of reservations
        reservations = c.get_reservations()
        self.assertEqual(len(reservations), 1)

        # should return immediately with list of reservations
        reservations = c.await_reservations()
        self.assertEqual(len(reservations), 1)

        # request server stop and wait (bounded) for it to take effect
        c.request_stop()
        self.assertTrue(self._wait_for_done(s))

    def test_reservation_environment_exists_get_server_ip_return_environment_value(self):
        tfos_server = Server(5)
        with mock.patch.dict(os.environ, {'TFOS_SERVER_HOST': 'my_host_ip'}):
            self.assertEqual(tfos_server.get_server_ip(), "my_host_ip")

    def test_reservation_environment_not_exists_get_server_ip_return_actual_host_ip(self):
        tfos_server = Server(5)
        # patch.dict snapshots os.environ and restores it on exit, so the
        # variable can be removed safely even if the CI environment sets it.
        with mock.patch.dict(os.environ):
            os.environ.pop('TFOS_SERVER_HOST', None)
            self.assertEqual(tfos_server.get_server_ip(), util.get_ip_address())

    def test_reservation_environment_exists_start_listening_socket_return_socket_listening_to_environment_port_value(self):
        tfos_server = Server(1)
        with mock.patch.dict(os.environ, {'TFOS_SERVER_PORT': '9999'}):
            sock = tfos_server.start_listening_socket()
        # Release the port so later tests (which bind 9998-9999) are not affected.
        self.addCleanup(sock.close)
        self.assertEqual(sock.getsockname()[1], 9999)

    def test_reservation_environment_not_exists_start_listening_socket_return_socket(self):
        tfos_server = Server(1)
        with mock.patch.dict(os.environ):
            os.environ.pop('TFOS_SERVER_PORT', None)
            sock = tfos_server.start_listening_socket()
        self.addCleanup(sock.close)
        self.assertIsInstance(sock.getsockname()[1], int)

    def test_reservation_environment_exists_port_spec(self):
        tfos_server = Server(1)
        with mock.patch.dict(os.environ, {'TFOS_SERVER_PORT': '9999'}):
            self.assertEqual(tfos_server.get_server_ports(), [9999])

        with mock.patch.dict(os.environ, {'TFOS_SERVER_PORT': '9997-9999'}):
            self.assertEqual(tfos_server.get_server_ports(), [9997, 9998, 9999])

    def test_reservation_environment_exists_start_listening_socket_return_socket_listening_to_environment_port_range(self):
        tfos_server1 = Server(1)
        tfos_server2 = Server(1)
        tfos_server3 = Server(1)
        self.addCleanup(tfos_server1.stop)
        self.addCleanup(tfos_server2.stop)

        with mock.patch.dict(os.environ, {'TFOS_SERVER_PORT': '9998-9999'}):
            s1 = tfos_server1.start_listening_socket()
            # Close the returned sockets explicitly; stop() is not shown to do so.
            self.addCleanup(s1.close)
            self.assertEqual(s1.getsockname()[1], 9998)

            s2 = tfos_server2.start_listening_socket()
            self.addCleanup(s2.close)
            self.assertEqual(s2.getsockname()[1], 9999)

            # Both ports in the range are now taken, so a third bind must fail.
            # If it unexpectedly succeeds, close the socket before the assertion fails.
            with self.assertRaises(Exception):
                tfos_server3.start_listening_socket().close()

    def test_reservation_server_multi(self):
        """Test reservation server, expecting multiple reservations"""
        num_clients = 4
        s = Server(num_clients)
        addr = s.start()
        self.addCleanup(s.stop)

        # Exceptions (including failed assertions) raised inside worker threads
        # do not propagate to the test runner, so each thread records its
        # outcome and the assertions are made from the main thread instead.
        responses = [None] * num_clients
        errors = []

        def reserve(num):
            try:
                c = Client(addr)
                try:
                    responses[num] = c.register({'node': num})
                    c.await_reservations()
                finally:
                    c.close()
            except Exception as e:  # surfaced via `errors` below
                errors.append(e)

        # start/register clients; daemon threads cannot block interpreter exit if one hangs
        threads = [
            threading.Thread(target=reserve, args=(i,), daemon=True)
            for i in range(num_clients)
        ]
        for t in threads:
            t.start()

        # wait (bounded) for clients to complete
        for t in threads:
            t.join(timeout=self._JOIN_TIMEOUT_SECS)
            self.assertFalse(t.is_alive(), "client thread did not finish in time")

        self.assertEqual(errors, [])
        self.assertEqual(responses, ['OK'] * num_clients)

        # get list of reservations
        c = Client(addr)
        self.addCleanup(c.close)
        reservations = c.get_reservations()
        self.assertEqual(len(reservations), num_clients)

        # request server stop and wait (bounded) for it to take effect
        c.request_stop()
        self.assertTrue(self._wait_for_done(s))
if __name__ == '__main__':
    # Running this file directly executes every test defined in the module.
    unittest.main()