[DistGB] add graphbolt flag into top level API (dmlc#7122)
Rhett-Ying authored Feb 22, 2024
1 parent 7a10bcb commit b0080d5
Showing 8 changed files with 88 additions and 43 deletions.
1 change: 1 addition & 0 deletions python/dgl/distributed/__init__.py
@@ -19,3 +19,4 @@
from .rpc_client import connect_to_server
from .rpc_server import start_server
from .server_state import ServerState
from .constants import *
3 changes: 3 additions & 0 deletions python/dgl/distributed/constants.py
@@ -7,3 +7,6 @@

DEFAULT_NTYPE = "_N"
DEFAULT_ETYPE = (DEFAULT_NTYPE, "_E", DEFAULT_NTYPE)

DATA_LOADING_BACKEND_DGL = "DGL"
DATA_LOADING_BACKEND_GRAPHBOLT = "GraphBolt"
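
Since `__init__.py` above now re-exports everything from `constants.py`, these backend names should be importable directly from the top-level package. A minimal sketch, assuming a DGL build that includes this commit:

from dgl.distributed import (
    DATA_LOADING_BACKEND_DGL,
    DATA_LOADING_BACKEND_GRAPHBOLT,
)

# Plain string constants used to select the data loading backend.
print(DATA_LOADING_BACKEND_DGL)        # "DGL"
print(DATA_LOADING_BACKEND_GRAPHBOLT)  # "GraphBolt"
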
11 changes: 10 additions & 1 deletion python/dgl/distributed/dist_context.py
@@ -14,7 +14,11 @@
from .. import utils
from ..base import dgl_warning, DGLError
from . import rpc
from .constants import MAX_QUEUE_SIZE
from .constants import (
DATA_LOADING_BACKEND_DGL,
DATA_LOADING_BACKEND_GRAPHBOLT,
MAX_QUEUE_SIZE,
)
from .kvstore import close_kvstore, init_kvstore
from .role import init_role
from .rpc_client import connect_to_server
@@ -210,6 +214,7 @@ def initialize(
max_queue_size=MAX_QUEUE_SIZE,
net_type=None,
num_worker_threads=1,
data_loading_backend=DATA_LOADING_BACKEND_DGL,
):
"""Initialize DGL's distributed module
@@ -231,6 +236,8 @@
[Deprecated] Networking type, can be 'socket' only.
num_worker_threads: int
The number of OMP threads in each sampler process.
data_loading_backend: str, optional
The backend for data loading. Can be 'DGL' or 'GraphBolt'.
Note
----
@@ -270,6 +277,8 @@ def initialize(
int(os.environ.get("DGL_NUM_CLIENT")),
os.environ.get("DGL_CONF_PATH"),
graph_format=formats,
use_graphbolt=data_loading_backend
== DATA_LOADING_BACKEND_GRAPHBOLT,
)
serv.start()
sys.exit()
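
With this change a trainer or sampler process selects the backend once, when the distributed module is initialized; inside `initialize()` the string is translated into the boolean `use_graphbolt` flag shown above. A minimal usage sketch, where `ip_config.txt` is a placeholder path to the usual server list:

import dgl
from dgl.distributed import DATA_LOADING_BACKEND_GRAPHBOLT

# Omitting data_loading_backend keeps the default DGL backend.
dgl.distributed.initialize(
    "ip_config.txt",
    data_loading_backend=DATA_LOADING_BACKEND_GRAPHBOLT,
)
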
70 changes: 53 additions & 17 deletions python/dgl/distributed/dist_graph.py
@@ -18,6 +18,7 @@
from ..ndarray import exist_shared_mem_array
from ..transforms import compact_graphs
from . import graph_services, role, rpc
from .constants import DATA_LOADING_BACKEND_DGL, DATA_LOADING_BACKEND_GRAPHBOLT
from .dist_tensor import DistTensor
from .graph_partition_book import (
_etype_str_to_tuple,
@@ -50,6 +51,7 @@
)

INIT_GRAPH = 800001
QUERY_DATA_LOADING_BACKEND = 800002


class InitGraphRequest(rpc.Request):
@@ -60,20 +62,19 @@ class InitGraphRequest(rpc.Request):
with shared memory.
"""

def __init__(self, graph_name, use_graphbolt):
def __init__(self, graph_name):
self._graph_name = graph_name
self._use_graphbolt = use_graphbolt

def __getstate__(self):
return self._graph_name, self._use_graphbolt
return self._graph_name

def __setstate__(self, state):
self._graph_name, self._use_graphbolt = state
self._graph_name = state

def process_request(self, server_state):
if server_state.graph is None:
server_state.graph = _get_graph_from_shared_mem(
self._graph_name, self._use_graphbolt
self._graph_name, server_state.use_graphbolt
)
return InitGraphResponse(self._graph_name)

@@ -91,6 +92,37 @@ def __setstate__(self, state):
self._graph_name = state


class QueryDataLoadingBackendRequest(rpc.Request):
"""Query the data loading backend."""

def __getstate__(self):
return None

def __setstate__(self, state):
pass

def process_request(self, server_state):
backend = (
DATA_LOADING_BACKEND_GRAPHBOLT
if server_state.use_graphbolt
else DATA_LOADING_BACKEND_DGL
)
return QueryDataLoadingBackendResponse(backend)


class QueryDataLoadingBackendResponse(rpc.Response):
"""Ack the query data loading backend request"""

def __init__(self, backend):
self._backend = backend

def __getstate__(self):
return self._backend

def __setstate__(self, state):
self._backend = state


def _copy_graph_to_shared_mem(g, graph_name, graph_format, use_graphbolt):
if use_graphbolt:
return g.copy_to_shared_memory(graph_name)
@@ -473,6 +505,7 @@ def start(self):
kv_store=self,
local_g=self.client_g,
partition_book=self.gpb,
use_graphbolt=self.use_graphbolt,
)
print(
"start graph service on server {} for part {}".format(
@@ -529,8 +562,6 @@ class DistGraph:
part_config : str, optional
The path of partition configuration file generated by
:py:meth:`dgl.distributed.partition.partition_graph`. It's used in the standalone mode.
use_graphbolt : bool, optional
Whether to load GraphBolt partition. Default: False.
Examples
--------
@@ -564,15 +595,11 @@ class DistGraph:
manually setting up servers and trainers. The setup is not fully tested yet.
"""

def __init__(
self, graph_name, gpb=None, part_config=None, use_graphbolt=False
):
def __init__(self, graph_name, gpb=None, part_config=None):
self.graph_name = graph_name
self._use_graphbolt = use_graphbolt
if os.environ.get("DGL_DIST_MODE", "standalone") == "standalone":
assert (
use_graphbolt is False
), "GraphBolt is not supported in standalone mode."
# "GraphBolt is not supported in standalone mode."
self._use_graphbolt = False
assert (
part_config is not None
), "When running in the standalone model, the partition config file is required"
@@ -610,12 +637,16 @@ def __init__(
self._client.map_shared_data(self._gpb)
rpc.set_num_client(1)
else:
# Query the main server about whether GraphBolt is used.
rpc.send_request(0, QueryDataLoadingBackendRequest())
self._use_graphbolt = (
rpc.recv_response()._backend == DATA_LOADING_BACKEND_GRAPHBOLT
)

self._init(gpb)
# Tell the backup servers to load the graph structure from shared memory.
for server_id in range(self._client.num_servers):
rpc.send_request(
server_id, InitGraphRequest(graph_name, use_graphbolt)
)
rpc.send_request(server_id, InitGraphRequest(graph_name))
for server_id in range(self._client.num_servers):
rpc.recv_response()
self._client.barrier()
@@ -1832,3 +1863,8 @@ def edge_split(


rpc.register_service(INIT_GRAPH, InitGraphRequest, InitGraphResponse)
rpc.register_service(
QUERY_DATA_LOADING_BACKEND,
QueryDataLoadingBackendRequest,
QueryDataLoadingBackendResponse,
)
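
The new query service follows the same `rpc` request/response pattern as `INIT_GRAPH`: both ends register the classes under `QUERY_DATA_LOADING_BACKEND`, the client sends the request, and the server answers from `server_state.use_graphbolt`. A rough sketch of the client-side round trip as `DistGraph.__init__` performs it, assuming the RPC layer is already connected (e.g. via `dgl.distributed.initialize`):

# Ask the main server (server 0) which data loading backend it was started with.
rpc.send_request(0, QueryDataLoadingBackendRequest())
backend = rpc.recv_response()._backend  # "DGL" or "GraphBolt"
use_graphbolt = backend == DATA_LOADING_BACKEND_GRAPHBOLT
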
10 changes: 9 additions & 1 deletion python/dgl/distributed/server_state.py
@@ -38,13 +38,16 @@ class ServerState:
Total number of edges
partition_book : GraphPartitionBook
Graph Partition book
use_graphbolt : bool
Whether to use graphbolt for dataloading.
"""

def __init__(self, kv_store, local_g, partition_book):
def __init__(self, kv_store, local_g, partition_book, use_graphbolt=False):
self._kv_store = kv_store
self._graph = local_g
self.partition_book = partition_book
self._roles = {}
self._use_graphbolt = use_graphbolt

@property
def roles(self):
@@ -69,5 +72,10 @@ def graph(self):
def graph(self, graph):
self._graph = graph

@property
def use_graphbolt(self):
"""Whether to use graphbolt for dataloading."""
return self._use_graphbolt


_init_api("dgl.distributed.server_state")
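
The read-only `use_graphbolt` property is what `QueryDataLoadingBackendRequest.process_request` inspects on the server side. A tiny illustration of that hand-off, with the kvstore, local graph, and partition book stubbed to `None` purely for brevity (a real server passes the objects built in `DistGraphServer.start`):

from dgl.distributed import DATA_LOADING_BACKEND_GRAPHBOLT
from dgl.distributed.dist_graph import QueryDataLoadingBackendRequest
from dgl.distributed.server_state import ServerState

state = ServerState(
    kv_store=None, local_g=None, partition_book=None, use_graphbolt=True
)
response = QueryDataLoadingBackendRequest().process_request(state)
assert response._backend == DATA_LOADING_BACKEND_GRAPHBOLT
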
8 changes: 4 additions & 4 deletions tests/distributed/test_dist_graph_store.py
@@ -141,7 +141,7 @@ def run_client_empty(
gpb, graph_name, _, _ = load_partition_book(
"/tmp/dist_graph/{}.json".format(graph_name), part_id
)
g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
g = DistGraph(graph_name, gpb=gpb)
check_dist_graph_empty(g, num_clients, num_nodes, num_edges)


@@ -222,7 +222,7 @@ def run_client(
gpb, graph_name, _, _ = load_partition_book(
"/tmp/dist_graph/{}.json".format(graph_name), part_id
)
g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
g = DistGraph(graph_name, gpb=gpb)
check_dist_graph(
g, num_clients, num_nodes, num_edges, use_graphbolt=use_graphbolt
)
@@ -322,7 +322,7 @@ def run_client_hierarchy(
gpb, graph_name, _, _ = load_partition_book(
"/tmp/dist_graph/{}.json".format(graph_name), part_id
)
g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
g = DistGraph(graph_name, gpb=gpb)
node_mask = F.tensor(node_mask)
edge_mask = F.tensor(edge_mask)
nodes = node_split(
@@ -742,7 +742,7 @@ def run_client_hetero(
gpb, graph_name, _, _ = load_partition_book(
"/tmp/dist_graph/{}.json".format(graph_name), part_id
)
g = DistGraph(graph_name, gpb=gpb, use_graphbolt=use_graphbolt)
g = DistGraph(graph_name, gpb=gpb)
check_dist_graph_hetero(
g, num_clients, num_nodes, num_edges, use_graphbolt=use_graphbolt
)
20 changes: 5 additions & 15 deletions tests/distributed/test_distributed_sampling.py
@@ -84,9 +84,7 @@ def start_sample_client_shuffle(
tmpdir / "test_sampling.json", rank
)
dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph(
"test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
)
dist_graph = DistGraph("test_sampling", gpb=gpb)
sampled_graph = sample_neighbors(
dist_graph, [0, 10, 99, 66, 1024, 2008], 3, use_graphbolt=use_graphbolt
)
@@ -477,9 +475,7 @@ def start_hetero_sample_client(
tmpdir / "test_sampling.json", rank
)
dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph(
"test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
)
dist_graph = DistGraph("test_sampling", gpb=gpb)
assert "feat" in dist_graph.nodes["n1"].data
assert "feat" not in dist_graph.nodes["n2"].data
assert "feat" not in dist_graph.nodes["n3"].data
@@ -517,9 +513,7 @@ def start_hetero_etype_sample_client(
tmpdir / "test_sampling.json", rank
)
dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph(
"test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
)
dist_graph = DistGraph("test_sampling", gpb=gpb)
assert "feat" in dist_graph.nodes["n1"].data
assert "feat" not in dist_graph.nodes["n2"].data
assert "feat" not in dist_graph.nodes["n3"].data
@@ -876,9 +870,7 @@ def start_bipartite_sample_client(
tmpdir / "test_sampling.json", rank
)
dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph(
"test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
)
dist_graph = DistGraph("test_sampling", gpb=gpb)
assert "feat" in dist_graph.nodes["user"].data
assert "feat" in dist_graph.nodes["game"].data
if gpb is None:
@@ -911,9 +903,7 @@ def start_bipartite_etype_sample_client(
tmpdir / "test_sampling.json", rank
)
dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph(
"test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
)
dist_graph = DistGraph("test_sampling", gpb=gpb)
assert "feat" in dist_graph.nodes["user"].data
assert "feat" in dist_graph.nodes["game"].data

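
As these updated tests show, clients no longer pass `use_graphbolt` to `DistGraph`; the choice is made once in `initialize()` and the graph object discovers it by querying the server. A condensed sketch of the client flow the sampling tests exercise, where the ip config, partition config path, part id, and seed node IDs are placeholders copied from the tests:

import dgl
from dgl.distributed import (
    DATA_LOADING_BACKEND_GRAPHBOLT,
    DistGraph,
    load_partition_book,
    sample_neighbors,
)

dgl.distributed.initialize(
    "rpc_ip_config.txt",
    data_loading_backend=DATA_LOADING_BACKEND_GRAPHBOLT,
)
gpb, graph_name, _, _ = load_partition_book(
    "/tmp/dist_graph/test_sampling.json", 0
)
g = DistGraph(graph_name, gpb=gpb)  # no use_graphbolt argument any more
sampled = sample_neighbors(
    g, [0, 10, 99, 66, 1024, 2008], 3, use_graphbolt=True
)
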
8 changes: 3 additions & 5 deletions tests/distributed/test_mp_dataloader.py
@@ -11,6 +11,8 @@
import torch as th
from dgl.data import CitationGraphDataset
from dgl.distributed import (
DATA_LOADING_BACKEND_DGL,
DATA_LOADING_BACKEND_GRAPHBOLT,
DistDataLoader,
DistGraph,
DistGraphServer,
@@ -104,7 +106,6 @@ def start_dist_dataloader(
"test_sampling",
gpb=gpb,
part_config=part_config,
use_graphbolt=use_graphbolt,
)

# Create sampler
@@ -443,7 +444,6 @@ def start_node_dataloader(
"test_sampling",
gpb=gpb,
part_config=part_config,
use_graphbolt=use_graphbolt,
)
assert len(dist_graph.ntypes) == len(groundtruth_g.ntypes)
assert len(dist_graph.etypes) == len(groundtruth_g.etypes)
@@ -763,9 +763,7 @@ def start_multiple_dataloaders(
use_graphbolt,
):
dgl.distributed.initialize(ip_config)
dist_g = dgl.distributed.DistGraph(
graph_name, part_config=part_config, use_graphbolt=use_graphbolt
)
dist_g = dgl.distributed.DistGraph(graph_name, part_config=part_config)
if dataloader_type == "node":
train_ids = th.arange(orig_g.num_nodes())
batch_size = orig_g.num_nodes() // 100
