[DistGB] remove toindex() as a torch tensor is always expected (dmlc…
Rhett-Ying authored Feb 26, 2024
1 parent 63541c8 commit 045beeb
Showing 8 changed files with 130 additions and 70 deletions.
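The pattern repeated across these files: call sites that used to normalize indices with utils.toindex(...).tousertensor() now either assume a torch.Tensor directly or route through the new dgl.distributed.utils.totensor() helper added in this commit. A schematic before/after sketch (illustrative, not copied from any single file):

    # Before this commit (old normalization, removed):
    idx = utils.toindex(idx)
    idx = idx.tousertensor()

    # After this commit (torch tensors pass through unchanged):
    idx = totensor(idx)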
9 changes: 4 additions & 5 deletions python/dgl/distributed/dist_tensor.py
@@ -2,12 +2,13 @@

import os

from .. import backend as F, utils
from .. import backend as F

from .dist_context import is_initialized
from .kvstore import get_kvstore
from .role import get_role
from .rpc import get_group_id
from .utils import totensor


def _default_init_data(shape, dtype):
@@ -200,13 +201,11 @@ def __del__(self):
self.kvstore.delete_data(self._name)

def __getitem__(self, idx):
idx = utils.toindex(idx)
idx = idx.tousertensor()
idx = totensor(idx)
return self.kvstore.pull(name=self._name, id_tensor=idx)

def __setitem__(self, idx, val):
idx = utils.toindex(idx)
idx = idx.tousertensor()
idx = totensor(idx)
# TODO(zhengda) how do we want to support broadcast (e.g., G.ndata['h'][idx] = 1).
self.kvstore.push(name=self._name, id_tensor=idx, data_tensor=val)

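Net effect on DistTensor indexing: torch tensors are used as-is, while other index-like inputs (lists, numpy arrays) are still converted by totensor(). A minimal sketch, assuming an initialized distributed setup and a hypothetical DistTensor named dist_t:

    import numpy as np
    import torch

    idx = torch.tensor([0, 5, 9])           # used directly as the kvstore id_tensor
    # rows = dist_t[idx]                     # kvstore.pull(name=..., id_tensor=idx)
    # dist_t[[0, 5, 9]] = rows               # a plain list is still converted internally
    # dist_t[np.array([0, 5, 9])] = rows     # likewise for a numpy array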
33 changes: 12 additions & 21 deletions python/dgl/distributed/graph_partition_book.py
@@ -4,8 +4,9 @@
from abc import ABC

import numpy as np
import torch

from .. import backend as F, utils
from .. import backend as F
from .._ffi.ndarray import empty_shared_mem
from ..base import DGLError
from ..ndarray import exist_shared_mem_array
@@ -761,7 +762,6 @@ def map_to_per_etype(self, ids):

def map_to_homo_nid(self, ids, ntype):
"""Map per-node-type IDs to global node IDs in the homogeneous format."""
ids = utils.toindex(ids).tousertensor()
partids = self.nid2partid(ids, ntype)
typed_max_nids = F.zerocopy_from_numpy(self._typed_max_node_ids[ntype])
end_diff = F.gather_row(typed_max_nids, partids) - ids
@@ -772,7 +772,6 @@ def map_to_homo_nid(self, ids, ntype):

def map_to_homo_eid(self, ids, etype):
"""Map per-edge-type IDs to global edge IDs in the homoenegeous format."""
ids = utils.toindex(ids).tousertensor()
c_etype = self.to_canonical_etype(etype)
partids = self.eid2partid(ids, c_etype)
typed_max_eids = F.zerocopy_from_numpy(
@@ -786,32 +785,28 @@ def map_to_homo_eid(self, ids, etype):

def nid2partid(self, nids, ntype=DEFAULT_NTYPE):
"""From global node IDs to partition IDs"""
nids = utils.toindex(nids)
# [TODO][Rui] replace numpy with torch.
nids = nids.numpy()
if ntype == DEFAULT_NTYPE:
ret = np.searchsorted(
self._max_node_ids, nids.tonumpy(), side="right"
)
ret = np.searchsorted(self._max_node_ids, nids, side="right")
else:
ret = np.searchsorted(
self._typed_max_node_ids[ntype], nids.tonumpy(), side="right"
self._typed_max_node_ids[ntype], nids, side="right"
)
ret = utils.toindex(ret)
return ret.tousertensor()
return torch.from_numpy(ret)

def eid2partid(self, eids, etype=DEFAULT_ETYPE):
"""From global edge IDs to partition IDs"""
eids = utils.toindex(eids)
# [TODO][Rui] replace numpy with torch.
eids = eids.numpy()
if etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]):
ret = np.searchsorted(
self._max_edge_ids, eids.tonumpy(), side="right"
)
ret = np.searchsorted(self._max_edge_ids, eids, side="right")
else:
c_etype = self.to_canonical_etype(etype)
ret = np.searchsorted(
self._typed_max_edge_ids[c_etype], eids.tonumpy(), side="right"
self._typed_max_edge_ids[c_etype], eids, side="right"
)
ret = utils.toindex(ret)
return ret.tousertensor()
return torch.from_numpy(ret)

def partid2nids(self, partid, ntype=DEFAULT_NTYPE):
"""From partition ID to global node IDs"""
@@ -852,8 +847,6 @@ def nid2localnid(self, nids, partid, ntype=DEFAULT_NTYPE):
getting remote tensor of nid2localnid."
)

nids = utils.toindex(nids)
nids = nids.tousertensor()
if ntype == DEFAULT_NTYPE:
start = self._max_node_ids[partid - 1] if partid > 0 else 0
else:
@@ -870,8 +863,6 @@ def eid2localeid(self, eids, partid, etype=DEFAULT_ETYPE):
getting remote tensor of eid2localeid."
)

eids = utils.toindex(eids)
eids = eids.tousertensor()
if etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]):
start = self._max_edge_ids[partid - 1] if partid > 0 else 0
else:
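For context, nid2partid/eid2partid locate the owning partition by binary search over the cumulative per-partition maximum IDs; this commit only changes the input/output handling (the input is now assumed to be a torch.Tensor, converted to numpy internally, and the result is returned via torch.from_numpy). A self-contained sketch with made-up partition boundaries:

    import numpy as np
    import torch

    max_node_ids = np.array([100, 250, 400])   # hypothetical: partition 0 owns IDs [0, 100), etc.
    nids = torch.tensor([3, 120, 399])

    part_ids = np.searchsorted(max_node_ids, nids.numpy(), side="right")
    print(torch.from_numpy(part_ids))          # tensor([0, 1, 2])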
49 changes: 32 additions & 17 deletions python/dgl/distributed/graph_services.py
@@ -14,7 +14,6 @@
sample_neighbors as local_sample_neighbors,
)
from ..subgraph import in_subgraph as local_in_subgraph
from ..utils import toindex
from .rpc import (
recv_responses,
register_service,
@@ -705,8 +704,6 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
"""
req_list = []
partition_book = g.get_partition_book()
if not isinstance(nodes, torch.Tensor):
nodes = toindex(nodes).tousertensor()
partition_id = partition_book.nid2partid(nodes)
local_nids = None
for pid in range(partition_book.num_partitions()):
@@ -900,11 +897,7 @@ def sample_etype_neighbors(
), "The sampled node type {} does not exist in the input graph".format(
ntype
)
if F.is_tensor(nodes[ntype]):
typed_nodes = nodes[ntype]
else:
typed_nodes = toindex(nodes[ntype]).tousertensor()
homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
homo_nids.append(gpb.map_to_homo_nid(nodes[ntype], ntype))
nodes = F.cat(homo_nids, 0)

def issue_remote_req(node_ids):
@@ -1029,11 +1022,7 @@ def sample_neighbors(
assert (
ntype in g.ntypes
), "The sampled node type does not exist in the input graph"
if F.is_tensor(nodes[ntype]):
typed_nodes = nodes[ntype]
else:
typed_nodes = toindex(nodes[ntype]).tousertensor()
homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
homo_nids.append(gpb.map_to_homo_nid(nodes[ntype], ntype))
nodes = F.cat(homo_nids, 0)
elif isinstance(nodes, dict):
assert len(nodes) == 1
@@ -1103,7 +1092,6 @@ def _distributed_edge_access(g, edges, issue_remote_req, local_access):
"""
req_list = []
partition_book = g.get_partition_book()
edges = toindex(edges).tousertensor()
partition_id = partition_book.eid2partid(edges)
local_eids = None
reorder_idx = []
@@ -1221,7 +1209,6 @@ def local_access(local_g, partition_book, local_nids):
def _distributed_get_node_property(g, n, issue_remote_req, local_access):
req_list = []
partition_book = g.get_partition_book()
n = toindex(n).tousertensor()
partition_id = partition_book.nid2partid(n)
local_nids = None
reorder_idx = []
@@ -1266,7 +1253,21 @@ def _distributed_get_node_property(g, n, issue_remote_req, local_access):


def in_degrees(g, v):
"""Get in-degrees"""
"""Get in-degrees
Parameters
----------
g : DistGraph
The distributed graph.
v : tensor
The node ID array.
Returns
-------
tensor
The in-degree array.
"""

def issue_remote_req(v, order_id):
return InDegreeRequest(v, order_id)
Expand All @@ -1278,7 +1279,21 @@ def local_access(local_g, partition_book, v):


def out_degrees(g, u):
"""Get out-degrees"""
"""Get out-degrees
Parameters
----------
g : DistGraph
The distributed graph.
u : tensor
The node ID array.
Returns
-------
tensor
The out-degree array.
"""

def issue_remote_req(u, order_id):
return OutDegreeRequest(u, order_id)
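A hypothetical call into the newly documented degree helpers, assuming an initialized DistGraph named g; node IDs are passed as torch tensors:

    import torch

    v = torch.tensor([0, 1, 2])
    # in_deg = in_degrees(g, v)     # tensor with one in-degree per node in `v`
    # out_deg = out_degrees(g, v)   # tensor with one out-degree per node in `v`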
8 changes: 1 addition & 7 deletions python/dgl/distributed/kvstore.py
@@ -4,7 +4,7 @@

import numpy as np

from .. import backend as F, utils
from .. import backend as F
from .._ffi.ndarray import empty_shared_mem

from . import rpc
Expand Down Expand Up @@ -1376,8 +1376,6 @@ def get_partid(self, name, id_tensor):
a vector storing the global data ID
"""
assert len(name) > 0, "name cannot be empty."
id_tensor = utils.toindex(id_tensor)
id_tensor = id_tensor.tousertensor()
assert F.ndim(id_tensor) == 1, "ID must be a vector."
# partition data
machine_id = self._part_policy[name].to_partid(id_tensor)
@@ -1399,8 +1397,6 @@ def push(self, name, id_tensor, data_tensor):
a tensor with the same row size of data ID
"""
assert len(name) > 0, "name cannot be empty."
id_tensor = utils.toindex(id_tensor)
id_tensor = id_tensor.tousertensor()
assert F.ndim(id_tensor) == 1, "ID must be a vector."
assert (
F.shape(id_tensor)[0] == F.shape(data_tensor)[0]
@@ -1452,8 +1448,6 @@ def pull(self, name, id_tensor):
a data tensor with the same row size of id_tensor.
"""
assert len(name) > 0, "name cannot be empty."
id_tensor = utils.toindex(id_tensor)
id_tensor = id_tensor.tousertensor()
assert F.ndim(id_tensor) == 1, "ID must be a vector."
if self._pull_handlers[name] is default_pull_handler: # Use fast-pull
part_id = self._part_policy[name].to_partid(id_tensor)
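With the conversion removed, the kvstore client assumes id_tensor is already a 1-D torch tensor. A rough sketch of a push/pull round trip under that assumption (kv_client, the data name, and shapes are hypothetical; a live kvstore is required):

    import torch

    ids = torch.tensor([0, 3, 7])          # must be a 1-D vector of IDs
    vals = torch.ones((3, 16))             # one data row per ID
    # kv_client.push(name="feat", id_tensor=ids, data_tensor=vals)
    # out = kv_client.pull(name="feat", id_tensor=ids)   # shape (3, 16)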
5 changes: 2 additions & 3 deletions python/dgl/distributed/nn/pytorch/sparse_emb.py
@@ -2,7 +2,7 @@

import torch as th

from .... import backend as F, utils
from .... import backend as F
from ...dist_tensor import DistTensor


@@ -99,7 +99,7 @@ def __init__(

def __call__(self, idx, device=th.device("cpu")):
"""
node_ids : th.tensor
idx : th.tensor
Index of the embeddings to collect.
device : th.device
Target device to put the collected embeddings.
@@ -109,7 +109,6 @@ def __call__(self, idx, device=th.device("cpu")):
Tensor
The requested node embeddings
"""
idx = utils.toindex(idx).tousertensor()
emb = self._tensor[idx].to(device, non_blocking=True)
if F.is_recording():
emb = F.attach_grad(emb)
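DistEmbedding.__call__ now indexes the backing DistTensor with idx directly, so lookups are expected to receive torch tensors. A hypothetical lookup (emb is assumed to be an already-constructed DistEmbedding):

    import torch as th

    idx = th.tensor([0, 4, 8])
    # vecs = emb(idx, device=th.device("cpu"))   # gathers the requested embedding rows
    # while gradients are being recorded, F.attach_grad() is applied to the result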
23 changes: 23 additions & 0 deletions python/dgl/distributed/utils.py
@@ -0,0 +1,23 @@
""" Utility functions for distributed training."""

import torch

from ..utils import toindex


def totensor(data):
"""Convert the given data to a tensor.
Parameters
----------
data : tensor, array, list or slice
Data to be converted.
Returns
-------
Tensor
Converted tensor.
"""
if isinstance(data, torch.Tensor):
return data
return toindex(data).tousertensor()
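A few illustrative calls to the new helper (PyTorch backend assumed): torch tensors are returned unchanged, while other index-like inputs still go through toindex().tousertensor():

    import torch
    from dgl.distributed.utils import totensor

    print(totensor(torch.tensor([1, 2, 3])))   # returned as-is
    print(totensor([1, 2, 3]))                 # list -> tensor([1, 2, 3])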
2 changes: 2 additions & 0 deletions tests/distributed/test_dist_tensor.py
@@ -57,6 +57,8 @@ def check_binary_op(key1, key2, key3, op):
dist_g.edata[key3][i:i_end],
op(dist_g.edata[key1][i:i_end], dist_g.edata[key2][i:i_end]),
)
_ = dist_g.edata[key3][F.tensor([100, 20, 10], F.int32)]
_ = dist_g.edata[key3][F.tensor([100, 20, 10], F.int64)]


@unittest.skipIf(