From 16561a2e6ba62e714b7d2cbbb538ae9d90afd9a5 Mon Sep 17 00:00:00 2001
From: Da Zheng
Date: Wed, 6 May 2020 00:33:53 -0700
Subject: [PATCH] [Test] Add tests for TensorFlow (#1501)

* add test.

* move test code.

* remove unnecessary test.

* fix.

* turn on tests for TF.

* Revert "move test code."

This reverts commit e7b4f36395b2121a7be030bd4364a704d0e357bf.

* fix.

* fix.

* skip test for tensorflow.

Co-authored-by: Chao Ma
---
 python/dgl/distributed/partition.py        | 10 ++++++----
 tests/distributed/test_partition.py        |  2 +-
 tests/distributed/test_shared_mem_store.py |  5 +++++
 tests/scripts/task_unit_test.sh            |  2 +-
 4 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py
index f3c40810e8a2..16cc7c244f33 100644
--- a/python/dgl/distributed/partition.py
+++ b/python/dgl/distributed/partition.py
@@ -140,9 +140,11 @@ def load_partition(conf_file, part_id):
 
     # TODO we need to fix this. DGL backend doesn't support boolean or byte.
     # int64 is unnecessary.
-    part_ids = F.zerocopy_from_numpy(node_map)[graph.ndata[NID]]
+    node_map = F.zerocopy_from_numpy(node_map)
+    part_ids = F.gather_row(node_map, graph.ndata[NID])
     graph.ndata['local_node'] = F.astype(part_ids == part_id, F.int64)
-    part_ids = F.zerocopy_from_numpy(edge_map)[graph.edata[EID]]
+    edge_map = F.zerocopy_from_numpy(edge_map)
+    part_ids = F.gather_row(edge_map, graph.edata[EID])
     graph.edata['local_edge'] = F.astype(part_ids == part_id, F.int64)
 
     return graph, node_feats, edge_feats, meta
@@ -252,9 +254,9 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
                                                         len(local_nodes), len(local_edges)))
             tot_num_inner_edges += len(local_edges)
             for name in g.ndata:
-                node_feats[name] = g.ndata[name][local_nodes]
+                node_feats[name] = F.gather_row(g.ndata[name], local_nodes)
             for name in g.edata:
-                edge_feats[name] = g.edata[name][local_edges]
+                edge_feats[name] = F.gather_row(g.edata[name], local_edges)
         else:
             for name in g.ndata:
                 node_feats[name] = g.ndata[name]
diff --git a/tests/distributed/test_partition.py b/tests/distributed/test_partition.py
index 424d08428b9c..253f75fc853f 100644
--- a/tests/distributed/test_partition.py
+++ b/tests/distributed/test_partition.py
@@ -47,7 +47,7 @@ def test_partition():
             assert name in node_feats
             assert node_feats[name].shape[0] == len(local_nodes)
             assert len(local_nodes) == len(node_feats[name])
-            assert np.all(F.asnumpy(g.ndata[name][local_nodes]) == F.asnumpy(node_feats[name]))
+            assert np.all(F.asnumpy(g.ndata[name])[local_nodes] == F.asnumpy(node_feats[name]))
 
         assert len(edge_feats) == 0
 
diff --git a/tests/distributed/test_shared_mem_store.py b/tests/distributed/test_shared_mem_store.py
index 415d10113b8e..dac13ea9171c 100644
--- a/tests/distributed/test_shared_mem_store.py
+++ b/tests/distributed/test_shared_mem_store.py
@@ -5,6 +5,7 @@
 """
 import dgl
 import sys
+import os
 import random
 import time
 import numpy as np
@@ -101,6 +102,7 @@ def server_func(num_workers, graph_name, server_init):
     server_init.value = 1
     g.run()
 
+@unittest.skipIf(os.getenv('DGLBACKEND') == 'tensorflow', reason="skip for tensorflow")
 def test_init():
     manager = Manager()
     return_dict = manager.dict()
@@ -168,6 +170,7 @@ def check_compute_func(worker_id, graph_name, return_dict):
         print(e, file=sys.stderr)
         traceback.print_exc()
 
+@unittest.skipIf(os.getenv('DGLBACKEND') == 'tensorflow', reason="skip for tensorflow")
 def test_compute():
     manager = Manager()
     return_dict = manager.dict()
@@ -215,6 +218,7 @@ def check_sync_barrier(worker_id, graph_name, return_dict):
         print(e, file=sys.stderr)
         traceback.print_exc()
 
+@unittest.skipIf(os.getenv('DGLBACKEND') == 'tensorflow', reason="skip for tensorflow")
 def test_sync_barrier():
     manager = Manager()
     return_dict = manager.dict()
@@ -275,6 +279,7 @@ def check_mem(gidx, cond_v, shared_v):
         cond_v.notify()
         cond_v.release()
 
+@unittest.skipIf(os.getenv('DGLBACKEND') == 'tensorflow', reason="skip for tensorflow")
 def test_copy_shared_mem():
     csr = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64)
     gidx = dgl.graph_index.create_graph_index(csr, True)
diff --git a/tests/scripts/task_unit_test.sh b/tests/scripts/task_unit_test.sh
index a341832c27e7..8cab53a47a87 100644
--- a/tests/scripts/task_unit_test.sh
+++ b/tests/scripts/task_unit_test.sh
@@ -37,6 +37,6 @@ python3 -m pytest -v --junitxml=pytest_gindex.xml tests/graph_index || fail "gra
 python3 -m pytest -v --junitxml=pytest_backend.xml tests/$DGLBACKEND || fail "backend-specific"
 
 export OMP_NUM_THREADS=1
-if [ $2 != "gpu" ] && [ $1 != "tensorflow" ]; then
+if [ $2 != "gpu" ]; then
     python3 -m pytest -v --junitxml=pytest_distributed.xml tests/distributed || fail "distributed"
 fi
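
The substantive change in python/dgl/distributed/partition.py replaces NumPy-style fancy indexing (tensor[index_tensor]) with F.gather_row, because integer-array indexing is not supported uniformly across DGL backends: TensorFlow in particular requires a gather primitive rather than __getitem__. Below is a minimal sketch of that backend-neutral pattern, assuming the public dgl.backend API (zerocopy_from_numpy, gather_row); the toy arrays are purely illustrative and are not taken from the patch.

    # Sketch of the backend-neutral gather pattern adopted by this patch.
    # The data here is made up; only the two F.* calls are DGL's API.
    import numpy as np
    import dgl.backend as F  # dispatches to the backend named by DGLBACKEND

    # Toy data: a per-node partition assignment and some global node ids.
    node_map = F.zerocopy_from_numpy(np.array([0, 0, 1, 1], dtype=np.int64))
    nids = F.zerocopy_from_numpy(np.array([2, 0, 3], dtype=np.int64))

    # node_map[nids] works under PyTorch/MXNet but not under TensorFlow,
    # which lacks NumPy-style integer-array indexing on tensors.
    # F.gather_row maps to the appropriate per-backend gather primitive.
    part_ids = F.gather_row(node_map, nids)  # partition id of each node in nids

The same concern motivates the one-line change in tests/distributed/test_partition.py, which converts the feature tensor to NumPy before indexing so the assertion stays backend-agnostic.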