diff --git a/python/dgl/distributed/__init__.py b/python/dgl/distributed/__init__.py
index b981ecb4fa90..0831f789a275 100644
--- a/python/dgl/distributed/__init__.py
+++ b/python/dgl/distributed/__init__.py
@@ -5,7 +5,7 @@
 from .dist_graph import DistGraphServer, DistGraph, DistTensor, node_split, edge_split
 from .partition import partition_graph, load_partition, load_partition_book
 from .graph_partition_book import GraphPartitionBook, RangePartitionBook, PartitionPolicy
-from .sparse_emb import SparseAdagrad, SparseNodeEmbedding
+from .sparse_emb import SparseAdagrad, DistEmbedding
 from .rpc import *
 from .rpc_server import start_server
diff --git a/python/dgl/distributed/dist_tensor.py b/python/dgl/distributed/dist_tensor.py
index cb052d67a046..0ea265308d9b 100644
--- a/python/dgl/distributed/dist_tensor.py
+++ b/python/dgl/distributed/dist_tensor.py
@@ -44,14 +44,14 @@ class DistTensor:
         The dtype of the tensor
     name : string
         The name of the tensor.
-    part_policy : PartitionPolicy
-        The partition policy of the tensor
     init_func : callable
         The function to initialize data in the tensor.
+    part_policy : PartitionPolicy
+        The partition policy of the tensor
     persistent : bool
         Whether the created tensor is persistent.
     '''
-    def __init__(self, g, shape, dtype, name=None, part_policy=None, init_func=None,
+    def __init__(self, g, shape, dtype, name=None, init_func=None, part_policy=None,
                  persistent=False):
         self.kvstore = g._client
         self._shape = shape
diff --git a/python/dgl/distributed/sparse_emb.py b/python/dgl/distributed/sparse_emb.py
index 379bfb3ba060..966b3b807215 100644
--- a/python/dgl/distributed/sparse_emb.py
+++ b/python/dgl/distributed/sparse_emb.py
@@ -1,30 +1,34 @@
 """Define sparse embedding and optimizer."""
 from .. import backend as F
+from .. import utils
 from .dist_tensor import DistTensor
 from .graph_partition_book import PartitionPolicy, NODE_PART_POLICY
 
 
-class SparseNodeEmbedding:
-    ''' Sparse embeddings in the distributed KVStore.
+class DistEmbedding:
+    '''Embeddings in the distributed training.
 
-    The sparse embeddings are only used as node embeddings.
+    By default, the embeddings are created for nodes in the graph.
 
     Parameters
     ----------
     g : DistGraph
         The distributed graph object.
+    num_embeddings : int
+        The number of embeddings
+    embedding_dim : int
+        The dimension size of embeddings.
     name : str
         The name of the embeddings
-    shape : tuple of int
-        The shape of the embedding. The first dimension should be the number of nodes.
-    initializer : callable
+    init_func : callable
         The function to create the initial data.
+    part_policy : PartitionPolicy
+        The partition policy.
 
     Examples
     --------
     >>> emb_init = lambda shape, dtype: F.zeros(shape, dtype, F.cpu())
-    >>> shape = (g.number_of_nodes(), 1)
-    >>> emb = dgl.distributed.SparseNodeEmbedding(g, 'emb1', shape, emb_init)
+    >>> emb = dgl.distributed.DistEmbedding(g, g.number_of_nodes(), 10)
     >>> optimizer = dgl.distributed.SparseAdagrad([emb], lr=0.001)
     >>> for blocks in dataloader:
     >>>     feats = emb(nids)
@@ -32,15 +36,17 @@ class SparseNodeEmbedding:
     >>>     loss.backward()
     >>>     optimizer.step()
     '''
-    def __init__(self, g, name, shape, initializer):
-        assert shape[0] == g.number_of_nodes()
-        part_policy = PartitionPolicy(NODE_PART_POLICY, g.get_partition_book())
-        g.ndata[name] = DistTensor(g, shape, F.float32, name, part_policy, initializer)
+    def __init__(self, g, num_embeddings, embedding_dim, name=None,
+                 init_func=None, part_policy=None):
+        if part_policy is None:
+            part_policy = PartitionPolicy(NODE_PART_POLICY, g.get_partition_book())
 
-        self._tensor = g.ndata[name]
+        self._tensor = DistTensor(g, (num_embeddings, embedding_dim), F.float32, name,
+                                  init_func, part_policy)
         self._trace = []
 
     def __call__(self, idx):
+        idx = utils.toindex(idx).tousertensor()
         emb = F.attach_grad(self._tensor[idx])
         self._trace.append((idx, emb))
         return emb
@@ -95,7 +101,7 @@ class SparseAdagrad:
 
     Parameters
    ----------
-    params : list of SparseNodeEmbeddings
+    params : list of DistEmbeddings
         The list of sparse embeddings.
     lr : float
         The learning rate.
diff --git a/tests/distributed/test_dist_graph_store.py b/tests/distributed/test_dist_graph_store.py
index a922522c5edd..bfd5ca9e85d7 100644
--- a/tests/distributed/test_dist_graph_store.py
+++ b/tests/distributed/test_dist_graph_store.py
@@ -13,7 +13,7 @@
 from dgl.data.utils import load_graphs, save_graphs
 from dgl.distributed import DistGraphServer, DistGraph
 from dgl.distributed import partition_graph, load_partition, load_partition_book, node_split, edge_split
-from dgl.distributed import SparseAdagrad, SparseNodeEmbedding
+from dgl.distributed import SparseAdagrad, DistEmbedding
 from numpy.testing import assert_almost_equal
 import backend as F
 import math
@@ -120,8 +120,7 @@ def check_dist_graph(g, num_nodes, num_edges):
 
     # Test sparse emb
     try:
-        new_shape = (g.number_of_nodes(), 1)
-        emb = SparseNodeEmbedding(g, 'emb1', new_shape, emb_init)
+        emb = DistEmbedding(g, g.number_of_nodes(), 1, 'emb1', emb_init)
         lr = 0.001
         optimizer = SparseAdagrad([emb], lr=lr)
         with F.record_grad():
@@ -142,7 +141,7 @@ def check_dist_graph(g, num_nodes, num_edges):
         assert np.all(F.asnumpy(grad_sum[nids]) == np.ones((len(nids), 1)))
         assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))
 
-        emb = SparseNodeEmbedding(g, 'emb2', new_shape, emb_init)
+        emb = DistEmbedding(g, g.number_of_nodes(), 1, 'emb2', emb_init)
         optimizer = SparseAdagrad([emb], lr=lr)
         with F.record_grad():
             feats1 = emb(nids)
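For reference, below is a minimal usage sketch of the renamed API, stitched together from the docstring example and the test changes in this patch. It is not part of the patch itself: the DistGraph `g`, the `dataloader`, and the per-batch seed node IDs `nids` are hypothetical placeholders that a real distributed training script would provide, and `loss.backward()` assumes the PyTorch backend.

# Usage sketch (not part of the patch): training node embeddings with the
# renamed DistEmbedding API, assuming a distributed setup is already running
# and `g` is an initialized dgl.distributed.DistGraph.
import dgl
from dgl import backend as F  # internal backend module, as used in the docstring example

emb_init = lambda shape, dtype: F.zeros(shape, dtype, F.cpu())

# One embedding row per node, 10 dimensions; part_policy=None falls back to
# the node partition policy inside DistEmbedding.__init__.
emb = dgl.distributed.DistEmbedding(g, g.number_of_nodes(), 10,
                                    name='emb1', init_func=emb_init)
optimizer = dgl.distributed.SparseAdagrad([emb], lr=0.001)

for blocks in dataloader:        # `dataloader` and the seed IDs `nids` are assumed to exist
    feats = emb(nids)            # gather embedding rows from the distributed KVStore
    loss = F.sum(feats + 1, 0)   # placeholder loss, mirroring the docstring example
    loss.backward()
    optimizer.step()             # push sparse Adagrad updates for the touched rows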