[Distributed] update Embedding API (dmlc#1847)
* update.

* update Embedding.

* add comments.

* fix lint

Co-authored-by: Chao Ma <[email protected]>
zheng-da and aksnzhy authored Jul 26, 2020
1 parent ec2e24b commit 6963d79
Showing 4 changed files with 27 additions and 22 deletions.
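The user-visible change is that SparseNodeEmbedding is replaced by DistEmbedding, which takes the number of embeddings and the embedding dimension directly instead of a name plus a full shape. A minimal before/after sketch (assuming an already-constructed dgl.distributed.DistGraph g; the distributed server/launch setup is omitted, and the name 'emb1' is only illustrative):

import dgl
import dgl.backend as F

# Assumed setup: `g` is an existing dgl.distributed.DistGraph.
emb_init = lambda shape, dtype: F.zeros(shape, dtype, F.cpu())

# Old API: node-data name plus full shape; the first dimension had to equal the node count.
# emb = dgl.distributed.SparseNodeEmbedding(g, 'emb1', (g.number_of_nodes(), 10), emb_init)

# New API: number of embeddings and embedding dimension, with the name, initializer,
# and partition policy as optional keyword arguments.
emb = dgl.distributed.DistEmbedding(g, g.number_of_nodes(), 10,
                                    name='emb1', init_func=emb_init)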
python/dgl/distributed/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -5,7 +5,7 @@
 from .dist_graph import DistGraphServer, DistGraph, DistTensor, node_split, edge_split
 from .partition import partition_graph, load_partition, load_partition_book
 from .graph_partition_book import GraphPartitionBook, RangePartitionBook, PartitionPolicy
-from .sparse_emb import SparseAdagrad, SparseNodeEmbedding
+from .sparse_emb import SparseAdagrad, DistEmbedding

 from .rpc import *
 from .rpc_server import start_server
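Downstream code picks up the renamed class from the same package path (SparseAdagrad is unchanged), mirroring the import update in the test file below:

# New import after this commit.
from dgl.distributed import SparseAdagrad, DistEmbedding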
python/dgl/distributed/dist_tensor.py (6 changes: 3 additions & 3 deletions)
@@ -44,14 +44,14 @@ class DistTensor:
         The dtype of the tensor
     name : string
         The name of the tensor.
-    part_policy : PartitionPolicy
-        The partition policy of the tensor
     init_func : callable
         The function to initialize data in the tensor.
+    part_policy : PartitionPolicy
+        The partition policy of the tensor
     persistent : bool
         Whether the created tensor is persistent.
     '''
-    def __init__(self, g, shape, dtype, name=None, part_policy=None, init_func=None,
+    def __init__(self, g, shape, dtype, name=None, init_func=None, part_policy=None,
                  persistent=False):
         self.kvstore = g._client
         self._shape = shape
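The DistTensor change is purely an argument reorder: init_func now precedes part_policy in the constructor, so positional callers written against the old order need to swap those two arguments; keyword arguments avoid the issue entirely. A small sketch under the same assumptions as above (existing DistGraph g; the tensor name 'feat' is illustrative):

import dgl
import dgl.backend as F

init = lambda shape, dtype: F.zeros(shape, dtype, F.cpu())

# New positional order: g, shape, dtype, name, init_func, part_policy, persistent.
tensor = dgl.distributed.DistTensor(g, (g.number_of_nodes(), 10), F.float32, 'feat',
                                    init, None)

# Equivalent keyword form, independent of the ordering.
tensor = dgl.distributed.DistTensor(g, (g.number_of_nodes(), 10), F.float32,
                                    name='feat', init_func=init, part_policy=None)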
python/dgl/distributed/sparse_emb.py (34 changes: 20 additions & 14 deletions)
@@ -1,46 +1,52 @@
 """Define sparse embedding and optimizer."""

 from .. import backend as F
 from .. import utils
 from .dist_tensor import DistTensor
 from .graph_partition_book import PartitionPolicy, NODE_PART_POLICY

-class SparseNodeEmbedding:
-    ''' Sparse embeddings in the distributed KVStore.
+class DistEmbedding:
+    '''Embeddings in the distributed training.
-    The sparse embeddings are only used as node embeddings.
+    By default, the embeddings are created for nodes in the graph.
     Parameters
     ----------
     g : DistGraph
         The distributed graph object.
+    num_embeddings : int
+        The number of embeddings
+    embedding_dim : int
+        The dimension size of embeddings.
     name : str
         The name of the embeddings
-    shape : tuple of int
-        The shape of the embedding. The first dimension should be the number of nodes.
-    initializer : callable
+    init_func : callable
         The function to create the initial data.
+    part_policy : PartitionPolicy
+        The partition policy.
     Examples
     --------
     >>> emb_init = lambda shape, dtype: F.zeros(shape, dtype, F.cpu())
-    >>> shape = (g.number_of_nodes(), 1)
-    >>> emb = dgl.distributed.SparseNodeEmbedding(g, 'emb1', shape, emb_init)
+    >>> emb = dgl.distributed.DistEmbedding(g, g.number_of_nodes(), 10)
     >>> optimizer = dgl.distributed.SparseAdagrad([emb], lr=0.001)
     >>> for blocks in dataloader:
     >>>     feats = emb(nids)
     >>>     loss = F.sum(feats + 1, 0)
     >>>     loss.backward()
     >>>     optimizer.step()
     '''
-    def __init__(self, g, name, shape, initializer):
-        assert shape[0] == g.number_of_nodes()
-        part_policy = PartitionPolicy(NODE_PART_POLICY, g.get_partition_book())
-        g.ndata[name] = DistTensor(g, shape, F.float32, name, part_policy, initializer)
+    def __init__(self, g, num_embeddings, embedding_dim, name=None,
+                 init_func=None, part_policy=None):
+        if part_policy is None:
+            part_policy = PartitionPolicy(NODE_PART_POLICY, g.get_partition_book())
+
-        self._tensor = g.ndata[name]
+        self._tensor = DistTensor(g, (num_embeddings, embedding_dim), F.float32, name,
+                                  init_func, part_policy)
         self._trace = []

     def __call__(self, idx):
         idx = utils.toindex(idx).tousertensor()
         emb = F.attach_grad(self._tensor[idx])
         self._trace.append((idx, emb))
         return emb
@@ -95,7 +101,7 @@ class SparseAdagrad:
     Parameters
     ----------
-    params : list of SparseNodeEmbeddings
+    params : list of DistEmbeddings
         The list of sparse embeddings.
     lr : float
         The learning rate.
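Reading the updated docstring example together with __call__ above: each lookup slices rows out of the backing DistTensor, attaches gradients, and records (idx, emb) in a trace so that the sparse optimizer can update only the rows that were actually touched. A one-step sketch under the same assumptions (existing DistGraph g; nids is a batch of node IDs from a sampler or dataloader):

import dgl
import dgl.backend as F

emb = dgl.distributed.DistEmbedding(g, g.number_of_nodes(), 10)
optimizer = dgl.distributed.SparseAdagrad([emb], lr=0.001)

feats = emb(nids)            # rows of the underlying DistTensor, gradients attached
loss = F.sum(feats + 1, 0)   # toy loss, taken from the docstring example
loss.backward()              # gradients accumulate only on the traced rows
optimizer.step()             # row-wise Adagrad update of the touched embeddings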
tests/distributed/test_dist_graph_store.py (7 changes: 3 additions & 4 deletions)
@@ -13,7 +13,7 @@
 from dgl.data.utils import load_graphs, save_graphs
 from dgl.distributed import DistGraphServer, DistGraph
 from dgl.distributed import partition_graph, load_partition, load_partition_book, node_split, edge_split
-from dgl.distributed import SparseAdagrad, SparseNodeEmbedding
+from dgl.distributed import SparseAdagrad, DistEmbedding
 from numpy.testing import assert_almost_equal
 import backend as F
 import math
@@ -120,8 +120,7 @@ def check_dist_graph(g, num_nodes, num_edges):

     # Test sparse emb
     try:
-        new_shape = (g.number_of_nodes(), 1)
-        emb = SparseNodeEmbedding(g, 'emb1', new_shape, emb_init)
+        emb = DistEmbedding(g, g.number_of_nodes(), 1, 'emb1', emb_init)
         lr = 0.001
         optimizer = SparseAdagrad([emb], lr=lr)
         with F.record_grad():
@@ -142,7 +141,7 @@ def check_dist_graph(g, num_nodes, num_edges):
         assert np.all(F.asnumpy(grad_sum[nids]) == np.ones((len(nids), 1)))
         assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))

-        emb = SparseNodeEmbedding(g, 'emb2', new_shape, emb_init)
+        emb = DistEmbedding(g, g.number_of_nodes(), 1, 'emb2', emb_init)
         optimizer = SparseAdagrad([emb], lr=lr)
         with F.record_grad():
             feats1 = emb(nids)
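For readers matching the test against the new signature, the positional call used here maps onto (g, num_embeddings, embedding_dim, name, init_func), so within the test's context it is equivalent to the keyword form:

# Keyword equivalent of the positional call used in the test.
emb = DistEmbedding(g, g.number_of_nodes(), 1, name='emb2', init_func=emb_init)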
