Skip to content

Commit

Permalink
Test shared-mem on kvstore (dmlc#976)
Browse files Browse the repository at this point in the history
  • Loading branch information
aksnzhy authored Nov 5, 2019
1 parent 7897fa3 commit 0b4935d
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions tests/compute/test_kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@

server_namebook = { 0:'127.0.0.1:50062' }

def start_server():
def start_server(server_embed):
server = dgl.contrib.KVServer(
server_id=0,
client_namebook=client_namebook,
server_addr=server_namebook[0])

server.init_data(name='server_embed', data_tensor=server_embed)

server.start()

def start_client():
def start_client(server_embed):
client = dgl.contrib.KVClient(
client_id=0,
server_namebook=server_namebook,
Expand All @@ -38,6 +40,7 @@ def start_client():
for i in range(5):
client.push(name='embed_0', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_0)
client.push(name='embed_1', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_1)
client.push(name='server_embed', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_1)

client.barrier()

Expand Down Expand Up @@ -65,23 +68,40 @@ def start_client():
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
client.pull(name='server_embed', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))

_, tensor_0 = client.pull_wait()
_, tensor_1 = client.pull_wait()
_, tensor_2 = client.pull_wait()
_, tensor_3 = client.pull_wait()
_, tensor_4 = client.pull_wait()

target_tensor_2 = th.tensor([ 2., 2., 7., 2., 12.])

assert th.equal(tensor_0, target_tensor_0) == True
assert th.equal(tensor_1, target_tensor_1) == True
assert th.equal(tensor_2, target_tensor_0) == True
assert th.equal(tensor_3, target_tensor_1) == True
assert th.equal(tensor_4, target_tensor_2) == True

server_embed += target_tensor_2

assert th.equal(new_tensor_0, target_tensor_0) == True
assert th.equal(new_tensor_1, target_tensor_1) == True
assert th.equal(new_tensor_2, target_tensor_0) == True
assert th.equal(new_tensor_3, target_tensor_1) == True
client.pull(name='server_embed', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
_, tensor_5 = client.pull_wait()

assert th.equal(tensor_5, target_tensor_2 * 2) == True

client.shut_down()

if __name__ == '__main__':
server_embed = th.tensor([2., 2., 2., 2., 2.])
server_embed.share_memory_()

pid = os.fork()
if pid == 0:
start_server()
start_server(server_embed)
else:
time.sleep(2) # wait server start
start_client()
start_client(server_embed)

assert th.equal(server_embed, th.tensor([ 4., 4., 14., 4., 24.])) == True

0 comments on commit 0b4935d

Please sign in to comment.