Skip to content

Commit

Permalink
[rpc] Move test_rpc.py to distributed and use dynamic port binding (d…
Browse files Browse the repository at this point in the history
…mlc#1623)

* update

* update

* update

* update

* update

* update

* update

* update
  • Loading branch information
aksnzhy authored Jun 11, 2020
1 parent 0c40860 commit 0c313e5
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 7 deletions.
39 changes: 35 additions & 4 deletions tests/distributed/test_dist_graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
import numpy as np
import time
import socket
from scipy import sparse as spsp
from numpy.testing import assert_array_equal
from multiprocessing import Process, Manager, Condition, Value
Expand All @@ -16,6 +17,35 @@
import unittest
import pickle

if os.name != 'nt':
import fcntl
import struct

def get_local_usable_addr():
"""Get local usable IP and port
Returns
-------
str
IP address, e.g., '192.168.8.12:50051'
"""
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
# doesn't even have to be reachable
sock.connect(('10.255.255.255', 1))
ip_addr = sock.getsockname()[0]
except ValueError:
ip_addr = '127.0.0.1'
finally:
sock.close()
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", 0))
sock.listen(1)
port = sock.getsockname()[1]
sock.close()

return ip_addr + ' ' + str(port)

def create_random_graph(n):
arr = (spsp.random(n, n, density=0.001, format='coo') != 0).astype(np.int64)
ig = create_graph_index(arr, readonly=True)
Expand Down Expand Up @@ -95,7 +125,7 @@ def test_server_client():

# Partition the graph
num_parts = 1
graph_name = 'test'
graph_name = 'dist_graph_test'
g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
partition_graph(g, graph_name, num_parts, '/tmp')
Expand Down Expand Up @@ -126,14 +156,14 @@ def test_split():
g = create_random_graph(10000)
num_parts = 4
num_hops = 2
partition_graph(g, 'test', num_parts, '/tmp', num_hops=num_hops, part_method='metis')
partition_graph(g, 'dist_graph_test', num_parts, '/tmp', num_hops=num_hops, part_method='metis')

node_mask = np.random.randint(0, 100, size=g.number_of_nodes()) > 30
edge_mask = np.random.randint(0, 100, size=g.number_of_edges()) > 30
selected_nodes = np.nonzero(node_mask)[0]
selected_edges = np.nonzero(edge_mask)[0]
for i in range(num_parts):
part_g, node_feats, edge_feats, meta = load_partition('/tmp/test.json', i)
part_g, node_feats, edge_feats, meta = load_partition('/tmp/dist_graph_test.json', i)
num_nodes, num_edges, node_map, edge_map, num_partitions = meta
gpb = GraphPartitionBook(part_id=i,
num_parts=num_partitions,
Expand All @@ -160,7 +190,8 @@ def test_split():

def prepare_dist():
ip_config = open("kv_ip_config.txt", "w")
ip_config.write('127.0.0.1 2500 1\n')
ip_addr = get_local_usable_addr()
ip_config.write('%s 1\n' % ip_addr)
ip_config.close()

if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions tests/distributed/test_graph_partition_book.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ def test_graph_partition_book():
num_parts = 4
num_hops = 2

partition_graph(g, 'test', num_parts, '/tmp', num_hops=num_hops, part_method='metis')
partition_graph(g, 'gpb_test', num_parts, '/tmp', num_hops=num_hops, part_method='metis')

for i in range(num_parts):
part_g, node_feats, edge_feats, meta = load_partition('/tmp/test.json', i)
part_g, node_feats, edge_feats, meta = load_partition('/tmp/gpb_test.json', i)
num_nodes, num_edges, node_map, edge_map, num_partitions = meta
gpb = GraphPartitionBook(part_id=i,
num_parts=num_partitions,
Expand Down
33 changes: 32 additions & 1 deletion tests/compute/test_rpc.py → tests/distributed/test_rpc.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,47 @@
import os
import time
import socket

import dgl
import backend as F
import unittest, pytest

from numpy.testing import assert_array_equal

if os.name != 'nt':
import fcntl
import struct

INTEGER = 2
STR = 'hello world!'
HELLO_SERVICE_ID = 901231
TENSOR = F.zeros((10, 10), F.int64, F.cpu())

def get_local_usable_addr():
"""Get local usable IP and port
Returns
-------
str
IP address, e.g., '192.168.8.12:50051'
"""
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
# doesn't even have to be reachable
sock.connect(('10.255.255.255', 1))
ip_addr = sock.getsockname()[0]
except ValueError:
ip_addr = '127.0.0.1'
finally:
sock.close()
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", 0))
sock.listen(1)
port = sock.getsockname()[1]
sock.close()

return ip_addr + ' ' + str(port)

def foo(x, y):
assert x == 123
assert y == "abc"
Expand Down Expand Up @@ -158,7 +188,8 @@ def test_rpc_msg():
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_rpc():
ip_config = open("rpc_ip_config.txt", "w")
ip_config.write('127.0.0.1 30050 1\n')
ip_addr = get_local_usable_addr()
ip_config.write('%s 1\n' % ip_addr)
ip_config.close()
pid = os.fork()
if pid == 0:
Expand Down

0 comments on commit 0c313e5

Please sign in to comment.