[Bug fix] Fix concurrency bug reported in issue #755 (dmlc#823)
* upd

* fix EdgeBatch edges

* add test

* trigger

* Update README.md for the PyTorch PinSage example.

Add a note that the PinSage model example under
example/pytorch/recommendation only works with Python 3.6+,
as its dataset loader depends on the stanfordnlp package,
which works only with Python 3.6+.

* Provide a framework-agnostic API to test nn modules on both the CPU and CUDA sides.

1. Make dgl.nn.xxx framework agnostic.
2. Make tests.backend include the dgl.nn modules.
3. Modify test_edge_softmax in tests/mxnet/test_nn.py and
    tests/pytorch/test_nn.py to work on both CPU and GPU.

* Fix style

* Delete unused code

* Make the agnostic tests depend only on tests/backend

1. Clear all agnostic-related code from dgl.nn.
2. Make test_graph_conv agnostic to CPU/GPU.

* Fix code style

* fix

* doc

* Make all test code under tests/mxnet/test_nn.py and tests/pytorch/test_nn.py
work on both CPU and GPU.

* Fix syntax

* Remove rand

* Add TAGCN nn.module and example

* Now TAGCN can run on CPU.

* Add unit test for TGConv

* Fix style

* For the Pubmed dataset, using --lr=0.005 achieves better accuracy

* Fix style

* Fix some descriptions

* trigger

* Fix doc

* Add nn.TGConv and example

* Fix bug

* Update the accuracy data in the mxnet TAGCN test.

* Fix some comments and code

* delete useless code

* Fix naming

* Fix bug

* Fix bug

* Add test for mxnet TAGConv

* Add test code for mxnet TAGConv

* Update some docs

* Fix some code

* Update docs for dgl.nn.mxnet

* Update weight init

* Fix

* reproduce the bug

* Fix the concurrency bug reported in dmlc#755.
Also make test_shared_mem_store.py more deterministic (see the sketch below the commit metadata).

* Update test_shared_mem_store.py

* Update dmlc/core
classicsong authored and zheng-da committed Oct 8, 2019
1 parent 5e17ef5 commit fd1b147
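
The core of the fix, visible in the diff below, is to replace fixed time.sleep() waits with explicit synchronization: the server publishes a readiness flag through a multiprocessing.Value, and the workers poll that flag before attaching to the shared-memory graph store. What follows is a minimal standalone sketch of that pattern, not DGL code; serve() and work() are hypothetical stand-ins for server_func and the check_*_func workers.

import time
from multiprocessing import Process, Value

def serve(server_init):
    # ... build the shared-memory graph store here (server_func does this) ...
    server_init.value = 1          # signal that initialization is complete
    time.sleep(5)                  # stand-in for g.run(), which serves worker requests

def work(worker_id):
    # ... attach to the graph store and run the checks (check_*_func does this) ...
    print("worker %d connected" % worker_id)

if __name__ == '__main__':
    server_init = Value('i', 0)    # shared readiness flag
    serv_p = Process(target=serve, args=(server_init,))
    serv_p.start()
    while server_init.value == 0:  # wait for the server instead of sleeping a fixed 3 seconds
        time.sleep(1)
    workers = [Process(target=work, args=(i,)) for i in range(2)]
    for p in workers:
        p.start()
    for p in workers:
        p.join()
    serv_p.join()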
Showing 1 changed file with 56 additions and 28 deletions.
84 changes: 56 additions & 28 deletions tests/distributed/test_shared_mem_store.py
@@ -9,7 +9,7 @@
import time
import numpy as np
from numpy.testing import assert_array_equal
from multiprocessing import Process, Manager
from multiprocessing import Process, Manager, Condition, Value
from scipy import sparse as spsp
import backend as F
import unittest
@@ -45,8 +45,6 @@ def create_graph_store(graph_name):
return None

def check_init_func(worker_id, graph_name, return_dict):
time.sleep(3)
print("worker starts")
np.random.seed(0)
csr = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64)

@@ -69,6 +67,7 @@ def check_init_func(worker_id, graph_name, return_dict):
g.init_edata('test4', (g.number_of_edges(), 10), 'float32')
g._sync_barrier(60)
check_array_shared_memory(g, worker_id, [g.nodes[:].data['test4'], g.edges[:].data['test4']])
g._sync_barrier(60)

data = g.nodes[:].data['test4']
g.set_n_repr({'test4': F.ones((1, 10)) * 10}, u=[0])
@@ -86,8 +85,7 @@ def check_init_func(worker_id, graph_name, return_dict):
print(e, file=sys.stderr)
traceback.print_exc()

def server_func(num_workers, graph_name):
print("server starts")
def server_func(num_workers, graph_name, server_init):
np.random.seed(0)
csr = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64)

@@ -99,16 +97,21 @@ def server_func(num_workers, graph_name):
efeat = np.arange(0, num_edges * 10).astype('float32').reshape((num_edges, 10))
g.ndata['feat'] = F.tensor(nfeat)
g.edata['feat'] = F.tensor(efeat)
server_init.value = 1
g.run()

@unittest.skip
def test_init():
manager = Manager()
return_dict = manager.dict()
serv_p = Process(target=server_func, args=(2, 'test_graph1'))

# make server init before worker
server_init = Value('i', False)
serv_p = Process(target=server_func, args=(2, 'test_graph1', server_init))
serv_p.start()
while server_init.value == 0:
time.sleep(1)
work_p1 = Process(target=check_init_func, args=(0, 'test_graph1', return_dict))
work_p2 = Process(target=check_init_func, args=(1, 'test_graph1', return_dict))
serv_p.start()
work_p1.start()
work_p2.start()
serv_p.join()
@@ -117,10 +120,7 @@ def test_init():
for worker_id in return_dict.keys():
assert return_dict[worker_id] == 0, "worker %d fails" % worker_id


def check_compute_func(worker_id, graph_name, return_dict):
time.sleep(3)
print("worker starts")
try:
g = create_graph_store(graph_name)
if g is None:
@@ -129,14 +129,14 @@ def check_compute_func(worker_id, graph_name, return_dict):

g._sync_barrier(60)
in_feats = g.nodes[0].data['feat'].shape[1]

# Test update all.
g.update_all(fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='preprocess'))
adj = g.adjacency_matrix()
tmp = F.spmm(adj, g.nodes[:].data['feat'])
assert_almost_equal(F.asnumpy(g.nodes[:].data['preprocess']), F.asnumpy(tmp))
g._sync_barrier(60)
check_array_shared_memory(g, worker_id, [g.nodes[:].data['preprocess']])
g._sync_barrier(60)

# Test apply nodes.
data = g.nodes[:].data['feat']
@@ -167,15 +167,18 @@ def check_compute_func(worker_id, graph_name, return_dict):
print(e, file=sys.stderr)
traceback.print_exc()


@unittest.skip
def test_compute():
manager = Manager()
return_dict = manager.dict()
serv_p = Process(target=server_func, args=(2, 'test_graph3'))

# make server init before worker
server_init = Value('i', 0)
serv_p = Process(target=server_func, args=(2, 'test_graph3', server_init))
serv_p.start()
while server_init.value == 0:
time.sleep(1)
work_p1 = Process(target=check_compute_func, args=(0, 'test_graph3', return_dict))
work_p2 = Process(target=check_compute_func, args=(1, 'test_graph3', return_dict))
serv_p.start()
work_p1.start()
work_p2.start()
serv_p.join()
@@ -185,8 +188,6 @@ def test_compute():
assert return_dict[worker_id] == 0, "worker %d fails" % worker_id

def check_sync_barrier(worker_id, graph_name, return_dict):
time.sleep(3)
print("worker starts")
try:
g = create_graph_store(graph_name)
if g is None:
@@ -213,14 +214,18 @@ def check_sync_barrier(worker_id, graph_name, return_dict):
print(e, file=sys.stderr)
traceback.print_exc()

@unittest.skip
def test_sync_barrier():
manager = Manager()
return_dict = manager.dict()
serv_p = Process(target=server_func, args=(2, 'test_graph4'))

# make server init before worker
server_init = Value('i', 0)
serv_p = Process(target=server_func, args=(2, 'test_graph4', server_init))
serv_p.start()
while server_init.value == 0:
time.sleep(1)
work_p1 = Process(target=check_sync_barrier, args=(0, 'test_graph4', return_dict))
work_p2 = Process(target=check_sync_barrier, args=(1, 'test_graph4', return_dict))
serv_p.start()
work_p1.start()
work_p2.start()
serv_p.join()
@@ -229,13 +234,28 @@ def test_sync_barrier():
for worker_id in return_dict.keys():
assert return_dict[worker_id] == 0, "worker %d fails" % worker_id

def create_mem(gidx):
def create_mem(gidx, cond_v, shared_v):
# serialize create_mem before check_mem
cond_v.acquire()
gidx1 = gidx.copyto_shared_mem("in", "test_graph5")
gidx2 = gidx.copyto_shared_mem("out", "test_graph6")
time.sleep(30)
shared_v.value = 1;
cond_v.notify()
cond_v.release()

# sync for exit
cond_v.acquire()
while shared_v.value == 1:
cond_v.wait()
cond_v.release()

def check_mem(gidx, cond_v, shared_v):
# check_mem should run after create_mem
cond_v.acquire()
while shared_v.value == 0:
cond_v.wait()
cond_v.release()

def check_mem(gidx):
time.sleep(10)
gidx1 = dgl.graph_index.from_shared_mem_csr_matrix("test_graph5", gidx.number_of_nodes(),
gidx.number_of_edges(), "in", False)
gidx2 = dgl.graph_index.from_shared_mem_csr_matrix("test_graph6", gidx.number_of_nodes(),
Expand All @@ -260,12 +280,20 @@ def check_mem(gidx):
gidx1 = gidx1.copyto_shared_mem("in", "test_graph5")
gidx2 = gidx2.copyto_shared_mem("out", "test_graph6")

@unittest.skip
#sync for exit
cond_v.acquire()
shared_v.value = 0;
cond_v.notify()
cond_v.release()

def test_copy_shared_mem():
csr = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64)
gidx = dgl.graph_index.create_graph_index(csr, False, True)
p1 = Process(target=create_mem, args=(gidx,))
p2 = Process(target=check_mem, args=(gidx,))

cond_v = Condition()
shared_v = Value('i', 0)
p1 = Process(target=create_mem, args=(gidx, cond_v, shared_v))
p2 = Process(target=check_mem, args=(gidx, cond_v, shared_v))
p1.start()
p2.start()
p1.join()
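
The second half of the change serializes create_mem and check_mem with a Condition plus a shared Value instead of the old time.sleep(30)/time.sleep(10): the creator signals once the shared-memory graphs exist, and it stays alive until the checker signals back, so the shared memory is not freed too early. Below is a minimal standalone sketch of that handshake with the DGL graph-index calls replaced by comments; do_setup and do_check are placeholders, not DGL APIs.

from multiprocessing import Process, Condition, Value

def create_mem(cond_v, shared_v):
    cond_v.acquire()
    # do_setup(): copy the graph index into shared memory (copyto_shared_mem in the test)
    shared_v.value = 1             # setup is done
    cond_v.notify()
    cond_v.release()

    # stay alive until the checker is done, so the shared memory is not released early
    cond_v.acquire()
    while shared_v.value == 1:
        cond_v.wait()
    cond_v.release()

def check_mem(cond_v, shared_v):
    # run only after create_mem has finished its setup
    cond_v.acquire()
    while shared_v.value == 0:
        cond_v.wait()
    cond_v.release()

    # do_check(): attach to the shared-memory graphs and verify them
    cond_v.acquire()
    shared_v.value = 0             # let create_mem exit
    cond_v.notify()
    cond_v.release()

if __name__ == '__main__':
    cond_v = Condition()
    shared_v = Value('i', 0)
    p1 = Process(target=create_mem, args=(cond_v, shared_v))
    p2 = Process(target=check_mem, args=(cond_v, shared_v))
    p1.start()
    p2.start()
    p1.join()
    p2.join()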
