[Test] Add tests for TensorFlow (dmlc#1501)
* add test.

* move test code.

* remove unnecessary test.

* fix.

* turn on tests for TF.

* Revert "move test code."

This reverts commit e7b4f36.

* fix.

* fix.

* skip test for tensorflow.

Co-authored-by: Chao Ma <[email protected]>
zheng-da and aksnzhy authored May 6, 2020
1 parent 6ae440d commit 16561a2
Showing 4 changed files with 13 additions and 6 deletions.
10 changes: 6 additions & 4 deletions python/dgl/distributed/partition.py
@@ -140,9 +140,11 @@ def load_partition(conf_file, part_id):
 
     # TODO we need to fix this. DGL backend doesn't support boolean or byte.
     # int64 is unnecessary.
-    part_ids = F.zerocopy_from_numpy(node_map)[graph.ndata[NID]]
+    node_map = F.zerocopy_from_numpy(node_map)
+    part_ids = F.gather_row(node_map, graph.ndata[NID])
     graph.ndata['local_node'] = F.astype(part_ids == part_id, F.int64)
-    part_ids = F.zerocopy_from_numpy(edge_map)[graph.edata[EID]]
+    edge_map = F.zerocopy_from_numpy(edge_map)
+    part_ids = F.gather_row(edge_map, graph.edata[EID])
     graph.edata['local_edge'] = F.astype(part_ids == part_id, F.int64)
 
     return graph, node_feats, edge_feats, meta
@@ -252,9 +254,9 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
                 len(local_nodes), len(local_edges)))
             tot_num_inner_edges += len(local_edges)
             for name in g.ndata:
-                node_feats[name] = g.ndata[name][local_nodes]
+                node_feats[name] = F.gather_row(g.ndata[name], local_nodes)
             for name in g.edata:
-                edge_feats[name] = g.edata[name][local_edges]
+                edge_feats[name] = F.gather_row(g.edata[name], local_edges)
         else:
             for name in g.ndata:
                 node_feats[name] = g.ndata[name]
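The substance of this change is the switch from NumPy-style fancy indexing to F.gather_row: TensorFlow tensors do not support tensor[index_tensor] indexing, so the backend-agnostic gather keeps load_partition working on every backend. Below is a minimal sketch of the pattern, assuming DGLBACKEND is set before dgl is imported; the node_map and nids values are made-up toy data.

# Minimal sketch of the backend-agnostic gather used in load_partition().
# Assumes DGLBACKEND (e.g. "pytorch" or "tensorflow") is set before importing dgl;
# node_map and nids below are made-up toy data.
import numpy as np
import dgl.backend as F

node_map = np.array([0, 0, 1, 1, 2], dtype=np.int64)   # partition id of each node
nids = F.zerocopy_from_numpy(np.array([4, 2, 0], dtype=np.int64))

node_map = F.zerocopy_from_numpy(node_map)
# node_map[nids] would fail on the TensorFlow backend, whose tensors do not
# support NumPy-style fancy indexing; F.gather_row dispatches to the right
# per-backend op (e.g. tf.gather) instead.
part_ids = F.gather_row(node_map, nids)
local_mask = F.astype(part_ids == 1, F.int64)
print(F.asnumpy(part_ids), F.asnumpy(local_mask))       # [2 1 0] [0 1 0]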
2 changes: 1 addition & 1 deletion tests/distributed/test_partition.py
@@ -47,7 +47,7 @@ def test_partition():
         assert name in node_feats
         assert node_feats[name].shape[0] == len(local_nodes)
         assert len(local_nodes) == len(node_feats[name])
-        assert np.all(F.asnumpy(g.ndata[name][local_nodes]) == F.asnumpy(node_feats[name]))
+        assert np.all(F.asnumpy(g.ndata[name])[local_nodes] == F.asnumpy(node_feats[name]))
     assert len(edge_feats) == 0


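The test fix is the mirror image of the library change: rather than fancy-indexing a backend tensor and converting afterwards, the tensor is converted to NumPy first and the indexing happens on the NumPy array, which every backend tolerates. A small sketch with made-up data:

# Sketch of the reordered indexing; feat and local_nodes are made-up stand-ins
# for g.ndata[name] and a partition's local node ids.
import numpy as np
import dgl.backend as F

feat = F.zerocopy_from_numpy(np.arange(10, dtype=np.float32))
local_nodes = np.array([7, 3, 1])

# F.asnumpy(feat)[local_nodes] indexes a NumPy array, so it is backend-safe;
# feat[local_nodes] would raise on TensorFlow tensors.
vals = F.asnumpy(feat)[local_nodes]
print(vals)   # [7. 3. 1.]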
5 changes: 5 additions & 0 deletions tests/distributed/test_shared_mem_store.py
@@ -5,6 +5,7 @@
 """
 import dgl
 import sys
+import os
 import random
 import time
 import numpy as np
@@ -101,6 +102,7 @@ def server_func(num_workers, graph_name, server_init):
     server_init.value = 1
     g.run()
 
+@unittest.skipIf(os.getenv('DGLBACKEND') == 'tensorflow', reason="skip for tensorflow")
 def test_init():
     manager = Manager()
     return_dict = manager.dict()
@@ -168,6 +170,7 @@ def check_compute_func(worker_id, graph_name, return_dict):
         print(e, file=sys.stderr)
         traceback.print_exc()
 
+@unittest.skipIf(os.getenv('DGLBACKEND') == 'tensorflow', reason="skip for tensorflow")
 def test_compute():
     manager = Manager()
     return_dict = manager.dict()
@@ -215,6 +218,7 @@ def check_sync_barrier(worker_id, graph_name, return_dict):
         print(e, file=sys.stderr)
         traceback.print_exc()
 
+@unittest.skipIf(os.getenv('DGLBACKEND') == 'tensorflow', reason="skip for tensorflow")
 def test_sync_barrier():
     manager = Manager()
     return_dict = manager.dict()
@@ -275,6 +279,7 @@ def check_mem(gidx, cond_v, shared_v):
     cond_v.notify()
     cond_v.release()
 
+@unittest.skipIf(os.getenv('DGLBACKEND') == 'tensorflow', reason="skip for tensorflow")
 def test_copy_shared_mem():
     csr = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64)
     gidx = dgl.graph_index.create_graph_index(csr, True)
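All four shared-memory tests get the same guard: unittest.skipIf evaluates the DGLBACKEND environment variable at import time, so under DGLBACKEND=tensorflow these tests are reported as skipped rather than failed. A self-contained sketch of the pattern (test_example and its body are placeholders):

# Self-contained sketch of the backend-based skip guard; test_example is a
# placeholder. Under pytest (as in task_unit_test.sh), unittest.skipIf on a
# plain function raises unittest.SkipTest when the condition holds, and
# pytest reports that as a skip.
import os
import unittest

@unittest.skipIf(os.getenv('DGLBACKEND') == 'tensorflow',
                 reason="skip for tensorflow")
def test_example():
    assert 1 + 1 == 2

Running a file like this with DGLBACKEND=tensorflow python3 -m pytest shows the test as skipped; with any other backend value it runs normally.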
2 changes: 1 addition & 1 deletion tests/scripts/task_unit_test.sh
@@ -37,6 +37,6 @@ python3 -m pytest -v --junitxml=pytest_gindex.xml tests/graph_index || fail "gra
 python3 -m pytest -v --junitxml=pytest_backend.xml tests/$DGLBACKEND || fail "backend-specific"
 
 export OMP_NUM_THREADS=1
-if [ $2 != "gpu" ] && [ $1 != "tensorflow" ]; then
+if [ $2 != "gpu" ]; then
     python3 -m pytest -v --junitxml=pytest_distributed.xml tests/distributed || fail "distributed"
 fi
