[Distributed] Distributed node embedding and sparse optimizer (dmlc#2733)

* Draft for sparse emb

* add some notes

* Fix

* Add sparse optim for dist pytorch

* Update test

* Fix

* upd

* upd

* Fix

* Fix

* Fix bug

* add transductive example

* Fix example

* Some fix

* Upd

* Fix lint

* lint

* lint

* lint

* upd

* Fix lint

* lint

* upd

* remove dead import

* update

* lint

* update unitest

* update example

* Add adam optimizer

* Add unitest and update data

* upd

* upd

* upd

* Fix docstring and fix some bug in example code

* Update rgcn readme

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
4 people authored May 3, 2021
1 parent 2d372e3 commit 975eb8f
Showing 31 changed files with 1,696 additions and 167 deletions.
13 changes: 11 additions & 2 deletions docs/source/api/python/dgl.distributed.rst
@@ -25,14 +25,23 @@ Distributed Tensor
.. autoclass:: DistTensor
:members: part_policy, shape, dtype, name

Distributed Embedding
Distributed Node Embedding
---------------------
.. currentmodule:: dgl.distributed.nn.pytorch

.. autoclass:: DistEmbedding
.. autoclass:: NodeEmbedding


Distributed embedding optimizer
-------------------------
.. currentmodule:: dgl.distributed.optim.pytorch

.. autoclass:: SparseAdagrad
:members: step

.. autoclass:: SparseAdam
:members: step
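
A quick sketch of how these classes fit together (illustrative only; ``g`` and ``input_nodes`` are assumed to come from the surrounding training script, and the exact constructor arguments, such as whether an embedding name must be passed, follow the usage shown in the user guide):

.. code:: python

    import torch as th
    import dgl

    # One learnable vector per node, sharded across the machines in the cluster.
    emb = dgl.distributed.nn.NodeEmbedding(
        g.number_of_nodes(), 128,
        init_func=lambda shape, dtype: th.zeros(shape, dtype=dtype))

    # Either sparse optimizer accepts a list of NodeEmbeddings.
    sparse_optim = dgl.distributed.optim.SparseAdam([emb], lr=0.01)

    feats = emb(input_nodes)   # gather rows for the nodes in a mini-batch
    loss = feats.sum()         # stand-in for a real model loss
    loss.backward()
    sparse_optim.step()        # apply sparse updates only to the touched rows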

Distributed workload split
--------------------------

16 changes: 8 additions & 8 deletions docs/source/guide/distributed-apis.rst
@@ -9,7 +9,7 @@ This section covers the distributed APIs used in the training script. DGL provid
data structures and various APIs for initialization, distributed sampling and workload split.
For distributed training/inference, DGL provides three distributed data structures:
:class:`~dgl.distributed.DistGraph` for distributed graphs, :class:`~dgl.distributed.DistTensor` for
distributed tensors and :class:`~dgl.distributed.DistEmbedding` for distributed learnable embeddings.
distributed tensors and :class:`~dgl.distributed.nn.NodeEmbedding` for distributed learnable embeddings.

Initialization of the DGL distributed module
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -27,7 +27,7 @@ Typically, the initialization APIs should be invoked in the following order:
th.distributed.init_process_group(backend='gloo')
**Note**: If the training script contains user-defined functions (UDFs) that have to be invoked on
the servers (see the section of DistTensor and DistEmbedding for more details), these UDFs have to
the servers (see the section of DistTensor and NodeEmbedding for more details), these UDFs have to
be declared before :func:`~dgl.distributed.initialize`.
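
A compact sketch of that ordering (the initializer below is just a placeholder UDF, reusing the one from the embedding example later in this guide):

.. code:: python

    import dgl
    import torch as th

    # 1. Declare any UDFs the servers must run (e.g. tensor initializers) first.
    def initializer(shape, dtype):
        arr = th.zeros(shape, dtype=dtype)
        arr.uniform_(-1, 1)
        return arr

    # 2. Initialize DGL's distributed module, then PyTorch's process group.
    dgl.distributed.initialize('ip_config.txt')
    th.distributed.init_process_group(backend='gloo')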

Distributed graph
@@ -125,7 +125,7 @@ in the cluster even if the :class:`~dgl.distributed.DistTensor` object disappear
tensor = dgl.distributed.DistTensor((g.number_of_nodes(), 10), th.float32, name='test')
**Note**: :class:`~dgl.distributed.DistTensor` creation is a synchronized operation. All trainers
have to invoke the creation and the creation succeeds only when all trainers call it.
have to invoke the creation and the creation succeeds only when all trainers call it.

A user can add a :class:`~dgl.distributed.DistTensor` to a :class:`~dgl.distributed.DistGraph`
object as one of the node data or edge data.
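
For instance, a minimal sketch reusing the ``g`` and ``tensor`` objects created above (the feature name is an arbitrary example):

.. code:: python

    # Register the distributed tensor as node data on the distributed graph;
    # afterwards it can be read and written like any other node feature.
    g.ndata['test_feat'] = tensor
    print(g.ndata['test_feat'][th.tensor([0, 1, 2])])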
@@ -153,10 +153,10 @@ computation operators, such as sum and mean.
when a machine runs multiple servers. This may result in data corruption. One way to avoid concurrent
writes to the same row of data is to run one server process on a machine.

Distributed Embedding
Distributed NodeEmbedding
~~~~~~~~~~~~~~~~~~~~~

DGL provides :class:`~dgl.distributed.DistEmbedding` to support transductive models that require
DGL provides :class:`~dgl.distributed.nn.NodeEmbedding` to support transductive models that require
node embeddings. Creating distributed embeddings is very similar to creating distributed tensors.

.. code:: python
@@ -165,7 +165,7 @@ node embeddings. Creating distributed embeddings is very similar to creating dis
arr = th.zeros(shape, dtype=dtype)
arr.uniform_(-1, 1)
return arr
emb = dgl.distributed.DistEmbedding(g.number_of_nodes(), 10, init_func=initializer)
emb = dgl.distributed.nn.NodeEmbedding(g.number_of_nodes(), 10, init_func=initializer)
Internally, distributed embeddings are built on top of distributed tensors, and, thus, have
very similar behaviors to distributed tensors. For example, when embeddings are created, they
@@ -192,7 +192,7 @@ the other for dense model parameters, as shown in the code below:
optimizer.step()
sparse_optimizer.step()
**Note**: :class:`~dgl.distributed.DistEmbedding` is not a PyTorch nn module, so we cannot
**Note**: :class:`~dgl.distributed.nn.NodeEmbedding` is not a PyTorch nn module, so we cannot
get access to it from the parameters of a PyTorch nn module.
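
Because of that, the usual pattern keeps two optimizers side by side, roughly as in the sketch below (``model``, ``loss_fcn`` and ``dataloader`` are assumed to come from the rest of the training script, and the embedding is created as shown earlier in this section):

.. code:: python

    emb = dgl.distributed.nn.NodeEmbedding(g.number_of_nodes(), 128,
                                           init_func=initializer)
    sparse_optimizer = dgl.distributed.optim.SparseAdagrad([emb], lr=0.05)
    optimizer = th.optim.Adam(model.parameters(), lr=0.01)  # dense parameters only

    for blocks in dataloader:
        input_nodes = blocks[0].srcdata[dgl.NID]
        batch_labels = blocks[-1].dstdata['labels']
        batch_pred = model(blocks, emb(input_nodes))
        loss = loss_fcn(batch_pred, batch_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()          # update dense model parameters
        sparse_optimizer.step()   # update only the embedding rows that were used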

Distributed sampling
@@ -252,7 +252,7 @@ the same as single-process sampling.
dataloader = dgl.sampling.NodeDataLoader(g, train_nid, sampler,
batch_size=batch_size, shuffle=True)
for batch in dataloader:
...
...
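
For reference, the ``sampler`` passed to the data loader is typically a thin wrapper around :func:`dgl.distributed.sample_neighbors`; a stripped-down sketch, adapted from the GraphSAGE example touched by this commit, looks roughly like:

.. code:: python

    import numpy as np
    import torch as th
    import dgl

    class NeighborSampler(object):
        def __init__(self, g, fanouts, sample_neighbors):
            self.g = g
            self.fanouts = fanouts
            self.sample_neighbors = sample_neighbors

        def sample_blocks(self, seeds):
            seeds = th.LongTensor(np.asarray(seeds))
            blocks = []
            for fanout in self.fanouts:
                # Sample a fixed number of neighbors for the current frontier.
                frontier = self.sample_neighbors(self.g, seeds, fanout)
                # Convert the sampled subgraph into a block and use its source
                # nodes as the seeds for the next (outer) layer.
                block = dgl.to_block(frontier, seeds)
                seeds = block.srcdata[dgl.NID]
                blocks.insert(0, block)
            return blocks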
Split workloads
6 changes: 3 additions & 3 deletions docs/source/guide/distributed.rst
@@ -16,7 +16,7 @@ For the training script, DGL provides distributed APIs that are similar to the o
mini-batch training. This makes distributed training require only small code modifications
from mini-batch training on a single machine. Below shows an example of training GraphSage
in a distributed fashion. The only code modifications are located on line 4-7:
1) initialize DGL's distributed module, 2) create a distributed graph object, and
1) initialize DGL's distributed module, 2) create a distributed graph object, and
3) split the training set and calculate the nodes for the local process.
The rest of the code, including sampler creation, model definition, training loops
are the same as :ref:`mini-batch training <guide-minibatch>`.
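
For orientation, those first lines look roughly like the following sketch (the graph name, partition layout and mask name are placeholders taken from the surrounding examples):

.. code:: python

    import dgl
    import torch as th

    dgl.distributed.initialize('ip_config.txt')          # 1) init DGL's distributed module
    th.distributed.init_process_group(backend='gloo')
    g = dgl.distributed.DistGraph('graph_name')          # 2) distributed graph object
    pb = g.get_partition_book()
    train_nid = dgl.distributed.node_split(              # 3) nodes owned by this trainer
        g.ndata['train_mask'], pb, force_even=True)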
@@ -35,7 +35,7 @@ are the same as :ref:`mini-batch training <guide-minibatch>`.
# Create sampler
sampler = NeighborSampler(g, [10,25],
dgl.distributed.sample_neighbors,
dgl.distributed.sample_neighbors,
device)
dataloader = DistDataLoader(
@@ -85,7 +85,7 @@ Specifically, DGL's distributed training has three types of interacting processe
generate mini-batches for training.
* Trainers contain multiple classes to interact with servers. It has
:class:`~dgl.distributed.DistGraph` to get access to partitioned graph data and has
:class:`~dgl.distributed.DistEmbedding` and :class:`~dgl.distributed.DistTensor` to access
:class:`~dgl.distributed.nn.NodeEmbedding` and :class:`~dgl.distributed.DistTensor` to access
the node/edge features/embeddings. It has
:class:`~dgl.distributed.dist_dataloader.DistDataLoader` to
interact with samplers to get mini-batches.
12 changes: 6 additions & 6 deletions docs/source/guide_cn/distributed-apis.rst
@@ -8,7 +8,7 @@
This section covers the distributed APIs used in the training script. DGL provides three distributed data structures and various APIs for initialization, distributed sampling and data splitting.
For distributed training/inference, DGL provides three distributed data structures: :class:`~dgl.distributed.DistGraph` for distributed graphs,
:class:`~dgl.distributed.DistTensor` for distributed tensors and, for distributed learnable embeddings,
:class:`~dgl.distributed.DistEmbedding`.
:class:`~dgl.distributed.nn.NodeEmbedding`.

Initialization of the DGL distributed module
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -24,7 +24,7 @@ Initialization of the DGL distributed module
dgl.distributed.initialize('ip_config.txt')
th.distributed.init_process_group(backend='gloo')
**Note**: If the training script contains user-defined functions (UDFs) that have to be invoked on the servers (see the DistTensor and DistEmbedding sections below for details),
**Note**: If the training script contains user-defined functions (UDFs) that have to be invoked on the servers (see the DistTensor and NodeEmbedding sections below for details),
these UDFs have to be declared before :func:`~dgl.distributed.initialize`.

Distributed graph
@@ -138,7 +138,7 @@ DGL provides interfaces for distributed tensors that are similar to regular single-machine tensors for accessing
Distributed embedding
~~~~~~~~~~~~~~~~~~~~~

DGL provides :class:`~dgl.distributed.DistEmbedding` to support transductive models that require node embeddings.
DGL provides :class:`~dgl.distributed.nn.NodeEmbedding` to support transductive models that require node embeddings.
Creating distributed embeddings is very similar to creating distributed tensors.

.. code:: python
@@ -147,7 +147,7 @@ DGL provides :class:`~dgl.distributed.DistEmbedding` to support transductive models that require node
arr = th.zeros(shape, dtype=dtype)
arr.uniform_(-1, 1)
return arr
emb = dgl.distributed.DistEmbedding(g.number_of_nodes(), 10, init_func=initializer)
emb = dgl.distributed.nn.NodeEmbedding(g.number_of_nodes(), 10, init_func=initializer)
Internally, distributed embeddings are built on top of distributed tensors and thus behave very similarly to distributed tensors.
For example, when embeddings are created, DGL shards them and stores them on all the machines in the cluster. They can be uniquely identified by name.
@@ -169,7 +169,7 @@ DGL provides a sparse Adagrad optimizer, :class:`~dgl.distributed.SparseAdagrad`
optimizer.step()
sparse_optimizer.step()
**Note**: :class:`~dgl.distributed.DistEmbedding` is not a PyTorch nn module, so users cannot access it through the parameters of an nn module.
**Note**: :class:`~dgl.distributed.nn.NodeEmbedding` is not a PyTorch nn module, so users cannot access it through the parameters of an nn module.

Distributed sampling
~~~~~~~~~~~~~~~~~~~~
@@ -228,7 +228,7 @@ DGL provides two levels of APIs for sampling nodes and edges to generate mini-batches
dataloader = dgl.sampling.NodeDataLoader(g, train_nid, sampler,
batch_size=batch_size, shuffle=True)
for batch in dataloader:
...
...
Splitting the dataset
4 changes: 2 additions & 2 deletions docs/source/guide_cn/distributed.rst
@@ -28,7 +28,7 @@ DGL adopts a fully distributed approach that distributes both the data and the computation across a collection of
# Create sampler
sampler = NeighborSampler(g, [10,25],
dgl.distributed.sample_neighbors,
dgl.distributed.sample_neighbors,
device)
dataloader = DistDataLoader(
@@ -74,7 +74,7 @@ DGL implements a few distributed components to support distributed training. The figure below shows these
These servers work together to serve the graph data to the trainers. Note that one machine may run multiple server processes at the same time to parallelize computation and network communication.
* *Sampler processes* interact with the servers and sample nodes and edges to generate mini-batches for training.
* *Trainer processes* contain multiple classes that interact with the servers. They use :class:`~dgl.distributed.DistGraph` to access the partitioned graph data, use
:class:`~dgl.distributed.DistEmbedding` and
:class:`~dgl.distributed.nn.NodeEmbedding` and
:class:`~dgl.distributed.DistTensor` to access node/edge features/embeddings, and use
:class:`~dgl.distributed.dist_dataloader.DistDataLoader` to interact with the samplers to get mini-batches.

52 changes: 49 additions & 3 deletions examples/pytorch/graphsage/experimental/README.md
@@ -118,7 +118,7 @@ The command below launches one training process on each machine and each trainin
python3 ~/workspace/dgl/tools/launch.py \
--workspace ~/workspace/dgl/examples/pytorch/graphsage/experimental/ \
--num_trainers 1 \
--num_samplers 4 \
--num_samplers 0 \
--num_servers 1 \
--part_config data/ogb-product.json \
--ip_config ip_config.txt \
@@ -131,7 +131,7 @@ To run unsupervised training:
python3 ~/workspace/dgl/tools/launch.py \
--workspace ~/workspace/dgl/examples/pytorch/graphsage/experimental/ \
--num_trainers 1 \
--num_samplers 4 \
--num_samplers 0 \
--num_servers 1 \
--part_config data/ogb-product.json \
--ip_config ip_config.txt \
@@ -144,13 +144,59 @@ By default, this code will run on CPU. If you have GPU support, you can just add
python3 ~/workspace/dgl/tools/launch.py \
--workspace ~/workspace/dgl/examples/pytorch/graphsage/experimental/ \
--num_trainers 4 \
--num_samplers 4 \
--num_samplers 0 \
--num_servers 1 \
--part_config data/ogb-product.json \
--ip_config ip_config.txt \
"python3 train_dist.py --graph_name ogb-product --ip_config ip_config.txt --num_epochs 30 --batch_size 1000 --num_gpus 4"
```

To run supervised training in the transductive setting (nodes are initialized with learnable node embeddings):
```bash
python3 ~/workspace/dgl/tools/launch.py --workspace ~/workspace/dgl/examples/pytorch/graphsage/experimental/ \
--num_trainers 4 \
--num_samplers 0 \
--num_servers 1 \
--part_config data/ogb-product.json \
--ip_config ip_config.txt \
"python3 train_dist_transductive.py --graph_name ogb-product --ip_config ip_config.txt --batch_size 1000 --num_gpu 4 --eval_every 5"
```

To run supervised training in the transductive setting using DGL's distributed NodeEmbedding:
```bash
python3 ~/workspace/dgl/tools/launch.py --workspace ~/workspace/dgl/examples/pytorch/graphsage/experimental/ \
--num_trainers 4 \
--num_samplers 0 \
--num_servers 1 \
--part_config data/ogb-product.json \
--ip_config ip_config.txt \
"python3 train_dist_transductive.py --graph_name ogb-product --ip_config ip_config.txt --batch_size 1000 --num_gpu 4 --eval_every 5 --dgl_sparse"
```

To run unsupervised training in the transductive setting (nodes are initialized with learnable node embeddings):
```bash
python3 ~/workspace/dgl/tools/launch.py --workspace ~/workspace/dgl/examples/pytorch/graphsage/experimental/ \
--num_trainers 4 \
--num_samplers 0 \
--num_servers 1 \
--part_config data/ogb-product.json \
--ip_config ip_config.txt \
"python3 train_dist_unsupervised_transductive.py --graph_name ogb-product --ip_config ip_config.txt --num_epochs 3 --batch_size 1000 --num_gpus 4"
```

To run unsupervised training in the transductive setting using DGL's distributed NodeEmbedding:
```bash
python3 ~/workspace/dgl/tools/launch.py --workspace ~/workspace/dgl/examples/pytorch/graphsage/experimental/ \
--num_trainers 4 \
--num_samplers 0 \
--num_servers 1 \
--part_config data/ogb-product.json \
--ip_config ip_config.txt \
"python3 train_dist_unsupervised_transductive.py --graph_name ogb-product --ip_config ip_config.txt --num_epochs 3 --batch_size 1000 --num_gpus 4 --dgl_sparse"
```

**Note:** if you are using conda or other virtual environments on the remote machines, you need to replace `python3` in the command string (i.e. the last argument) with the path to the Python interpreter in that environment.

## Distributed code runs in the standalone mode
6 changes: 2 additions & 4 deletions examples/pytorch/graphsage/experimental/ip_config.txt
@@ -1,4 +1,2 @@
172.31.19.1
172.31.23.205
172.31.29.175
172.31.16.98
172.31.2.66
172.31.1.191
14 changes: 8 additions & 6 deletions examples/pytorch/graphsage/experimental/train_dist.py
@@ -21,20 +21,21 @@
import torch.multiprocessing as mp
from torch.utils.data import DataLoader

def load_subtensor(g, seeds, input_nodes, device):
def load_subtensor(g, seeds, input_nodes, device, load_feat=True):
"""
Copys features and labels of a set of nodes onto GPU.
"""
batch_inputs = g.ndata['features'][input_nodes].to(device)
batch_inputs = g.ndata['features'][input_nodes].to(device) if load_feat else None
batch_labels = g.ndata['labels'][seeds].to(device)
return batch_inputs, batch_labels

class NeighborSampler(object):
def __init__(self, g, fanouts, sample_neighbors, device):
def __init__(self, g, fanouts, sample_neighbors, device, load_feat=True):
self.g = g
self.fanouts = fanouts
self.sample_neighbors = sample_neighbors
self.device = device
self.load_feat=load_feat

def sample_blocks(self, seeds):
seeds = th.LongTensor(np.asarray(seeds))
@@ -51,8 +52,9 @@ def sample_blocks(self, seeds):

input_nodes = blocks[0].srcdata[dgl.NID]
seeds = blocks[-1].dstdata[dgl.NID]
batch_inputs, batch_labels = load_subtensor(self.g, seeds, input_nodes, "cpu")
blocks[0].srcdata['features'] = batch_inputs
batch_inputs, batch_labels = load_subtensor(self.g, seeds, input_nodes, "cpu", self.load_feat)
if self.load_feat:
blocks[0].srcdata['features'] = batch_inputs
blocks[-1].dstdata['labels'] = batch_labels
return blocks
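
The new `load_feat` flag lets the transductive scripts skip fetching `g.ndata['features']` when the model's inputs come instead from a trainable distributed NodeEmbedding. A rough sketch of how a caller might use it (the exact wiring in the new transductive scripts may differ):

```python
# Sampler that attaches labels only; node features are looked up from a
# distributed NodeEmbedding inside the training loop instead.
sampler = NeighborSampler(g, [10, 25], dgl.distributed.sample_neighbors,
                          device, load_feat=False)

for blocks in dataloader:
    input_nodes = blocks[0].srcdata[dgl.NID]
    batch_inputs = emb(input_nodes)               # trainable node embeddings
    batch_labels = blocks[-1].dstdata['labels']
    batch_pred = model(blocks, batch_inputs)
```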

@@ -289,7 +291,7 @@ def main(args):
parser.add_argument('--part_config', type=str, help='The path to the partition config file')
parser.add_argument('--num_clients', type=int, help='The number of clients')
parser.add_argument('--n_classes', type=int, help='the number of classes')
parser.add_argument('--num_gpus', type=int, default=-1,
parser.add_argument('--num_gpus', type=int, default=-1,
help="the number of GPU device. Use -1 for CPU training")
parser.add_argument('--num_epochs', type=int, default=20)
parser.add_argument('--num_hidden', type=int, default=16)