From 78e0dae69343ba47077f575584bc4f68a58d8a0c Mon Sep 17 00:00:00 2001
From: Vasimuddin Md
Date: Wed, 15 Dec 2021 07:32:31 +0530
Subject: [PATCH] [DistGNN, Graph partitioning] Libra partition (#3376)

* added distgnn plus libra codebase
* Dist application codes
* added comments in partition code. changed the interface of partitioning call.
* updated readme
* create libra partitioning branch for the PR
* removed disgnn files for first PR
* updated kernel.cc
* added libra_partition.cc and moved libra code from kernel.cc to libra_partition.cc
* fixed lint error; merged libra2dgl.py and main_Libra.py to libra_partition.py; added graphsage/distgnn folder and partition script.
* removed libra2dgl.py
* fixed the lint error and cleaned the code.
* revisions due to PR comments. added distgnn/tools contains partitions routines
* update 2 PR revision I
* fixed errors; also improved the runtime by 10x.
* fixed minor lint error
* fixed some more lints
* PR revision II changed the interface of libra partition function
* rewrite docstring

Co-authored-by: Quan (Andy) Gan
---
 examples/pytorch/graphsage/distgnn/README.md  |  32 +
 .../graphsage/distgnn/partition_graph.py      |  62 ++
 .../pytorch/graphsage/experimental/README.md  |  35 +
 python/dgl/distgnn/__init__.py                |   5 +
 python/dgl/distgnn/partition/__init__.py      |   4 +
 .../dgl/distgnn/partition/libra_partition.py  | 291 ++++++++
 python/dgl/distgnn/tools/__init__.py          |   4 +
 python/dgl/distgnn/tools/tools.py             | 155 ++++
 python/dgl/sparse.py                          | 100 +++
 src/array/kernel.cc                           |   1 -
 src/array/libra_partition.cc                  | 690 ++++++++++++++++++
 11 files changed, 1378 insertions(+), 1 deletion(-)
 create mode 100644 examples/pytorch/graphsage/distgnn/README.md
 create mode 100644 examples/pytorch/graphsage/distgnn/partition_graph.py
 create mode 100644 python/dgl/distgnn/__init__.py
 create mode 100644 python/dgl/distgnn/partition/__init__.py
 create mode 100644 python/dgl/distgnn/partition/libra_partition.py
 create mode 100644 python/dgl/distgnn/tools/__init__.py
 create mode 100644 python/dgl/distgnn/tools/tools.py
 create mode 100644 src/array/libra_partition.cc

diff --git a/examples/pytorch/graphsage/distgnn/README.md b/examples/pytorch/graphsage/distgnn/README.md
new file mode 100644
index 000000000000..7cc6d7e138f9
--- /dev/null
+++ b/examples/pytorch/graphsage/distgnn/README.md
@@ -0,0 +1,32 @@
## DistGNN vertex-cut based graph partitioning (using Libra)

### How to run graph partitioning
```python partition_graph.py --dataset <dataset-name> --num-parts <num-partitions> --out-dir <output-location>```

Example: The following command line creates 4 partitions of the pubmed graph
```python partition_graph.py --dataset pubmed --num-parts 4 --out-dir ./```

The output partitions are created in the Libra_result_<dataset>/ folder under the given output directory.
The *upcoming DistGNN* application can directly use these partitions for distributed training.

### How Libra partitioning works
Libra is a vertex-cut based graph partitioning method. It applies a greedy heuristic to uniquely distribute the edges of the input graph among the partitions, and it emits each partition as a list of edges. The script ```libra_partition.py``` first generates the Libra partitions and then converts the Libra output to the DGL/DistGNN input format.


Note: The current Libra implementation is sequential. Extra overhead is paid for the additional work of converting the partitioned graph to DGL format.
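The same partitioning can also be driven from Python. Below is a minimal sketch of that flow (illustrative only: `PubmedGraphDataset` stands in for any DGL graph carrying feature/label/mask node data; `partition_graph` is the API added by this PR):

```python
import os
from dgl.data import PubmedGraphDataset
from dgl.distgnn.partition import partition_graph

g = PubmedGraphDataset()[0]   # any DGLGraph with feat/label and train/test/val masks
num_parts = 4
# the "Libra_result_" prefix is expected by the partitioning code
resultdir = os.path.join('./', 'Libra_result_pubmed')
partition_graph(num_parts, g, resultdir)   # writes ./Libra_result_pubmed/4Communities/
```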


### Expected partitioning timings
Cora, Pubmed, Citeseer: < 10 sec (<10GB)
Reddit: ~150 sec (~25GB)
OGBN-Products: ~200 sec (~30GB)
Proteins: 1800 sec (Format conversion from public data takes time) (~100GB)
OGBN-Paper100M: 2500 sec (~200GB)


### Settings
Tested with:
CentOS 7.6
gcc v8.3.0
PyTorch 1.7.1
Python 3.7.10
diff --git a/examples/pytorch/graphsage/distgnn/partition_graph.py b/examples/pytorch/graphsage/distgnn/partition_graph.py
new file mode 100644
index 000000000000..cd87c494cc14
--- /dev/null
+++ b/examples/pytorch/graphsage/distgnn/partition_graph.py
@@ -0,0 +1,62 @@
r"""
Copyright (c) 2021 Intel Corporation
 \file Graph partitioning
 \brief Calls Libra - Vertex-cut based graph partitioner for distributed training
 \author Vasimuddin Md ,
         Guixiang Ma
         Sanchit Misra ,
         Ramanarayan Mohanty ,
         Sasikanth Avancha
         Nesreen K. Ahmed
"""


import os
import argparse
from load_graph import load_ogb
from dgl.data import load_data
from dgl.distgnn.partition import partition_graph
from dgl.distgnn.tools import load_proteins
from dgl.base import DGLError


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--dataset', type=str, default='cora')
    argparser.add_argument('--num-parts', type=int, default=2)
    argparser.add_argument('--out-dir', type=str, default='./')
    args = argparser.parse_args()

    dataset = args.dataset
    num_community = args.num_parts
    out_dir = 'Libra_result_' + dataset   ## "Libra_result_" prefix is mandatory
    resultdir = os.path.join(args.out_dir, out_dir)

    print("Input dataset for partitioning: ", dataset)
    if args.dataset == 'ogbn-products':
        print("Loading ogbn-products")
        G, _ = load_ogb('ogbn-products')
    elif args.dataset == 'ogbn-papers100M':
        print("Loading ogbn-papers100M")
        G, _ = load_ogb('ogbn-papers100M')
    elif args.dataset == 'proteins':
        G = load_proteins('proteins')
    elif args.dataset == 'ogbn-arxiv':
        print("Loading ogbn-arxiv")
        G, _ = load_ogb('ogbn-arxiv')
    else:
        try:
            G = load_data(args)[0]
        except Exception:
            raise DGLError("Error: Dataset {} not found !!!".format(dataset))

    print("Done loading the graph.", flush=True)

    partition_graph(num_community, G, resultdir)
diff --git a/examples/pytorch/graphsage/experimental/README.md b/examples/pytorch/graphsage/experimental/README.md
index be0110e35b56..a10e9b0188b4 100644
--- a/examples/pytorch/graphsage/experimental/README.md
+++ b/examples/pytorch/graphsage/experimental/README.md
@@ -1,3 +1,38 @@
## DistGNN vertex-cut based graph partitioning (using Libra)

### How to run graph partitioning
```python ../../../../python/dgl/distgnn/partition/main_Libra.py <dataset-name> <#partitions>```

Example: The following command line creates 4 partitions of the pubmed graph
```python ../../../../python/dgl/distgnn/partition/main_Libra.py pubmed 4```

The output partitions are created in the current directory in the Libra_result_<dataset>/ folder.
The *upcoming DistGNN* application can directly use these partitions for distributed training.

### How Libra partitioning works
Libra is a vertex-cut based graph partitioning method. It applies a greedy heuristic to uniquely distribute the edges of the input graph among the partitions, and it emits each partition as a list of edges. The script ```main_Libra.py```, after generating the Libra partitions, converts the Libra output to the DGL/DistGNN input format.
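Each partZ folder holds a standard DGL graph file, so a finished partition can be sanity-checked with the usual loader (a minimal sketch; the path assumes the 4-way pubmed example above):

```python
from dgl.data.utils import load_graphs

# load partition 0 of the 4-way pubmed run
part_g = load_graphs('Libra_result_pubmed/4Communities/part0/graph.dgl')[0][0]
print(part_g.number_of_nodes(), part_g.number_of_edges())
print(part_g.ndata.keys())   # adj, inner_node, feat, lf, label, and the masks
```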


Note: The current Libra implementation is sequential. Extra overhead is paid for the additional work of converting the partitioned graph to DGL format.


### Expected partitioning timings
Cora, Pubmed, Citeseer: < 10 sec (<10GB)
Reddit: 1500 sec (~25GB)
OGBN-Products: ~2000 sec (~30GB)
Proteins: 18000 sec (Format conversion from public data takes time) (~100GB)
OGBN-Paper100M: 25000 sec (~200GB)


### Settings
Tested with:
CentOS 7.6
gcc v8.3.0
PyTorch 1.7.1
Python 3.7.10



## Distributed training
This is an example of training GraphSage in a distributed fashion. Before training, please install some python libs by pip:
diff --git a/python/dgl/distgnn/__init__.py b/python/dgl/distgnn/__init__.py
new file mode 100644
index 000000000000..604014757967
--- /dev/null
+++ b/python/dgl/distgnn/__init__.py
@@ -0,0 +1,5 @@
"""
This package contains DistGNN and Libra based graph partitioning tools.
"""
from . import partition
from . import tools
diff --git a/python/dgl/distgnn/partition/__init__.py b/python/dgl/distgnn/partition/__init__.py
new file mode 100644
index 000000000000..e884cf4dca97
--- /dev/null
+++ b/python/dgl/distgnn/partition/__init__.py
@@ -0,0 +1,4 @@
"""
This package contains the Libra graph partitioner.
"""
from .libra_partition import partition_graph
diff --git a/python/dgl/distgnn/partition/libra_partition.py b/python/dgl/distgnn/partition/libra_partition.py
new file mode 100644
index 000000000000..322c803899c6
--- /dev/null
+++ b/python/dgl/distgnn/partition/libra_partition.py
@@ -0,0 +1,291 @@
r"""Libra partition functions.

Libra partition is a vertex-cut based partitioning algorithm from
`Distributed Power-law Graph Computing:
Theoretical and Empirical Analysis
`__
by Xie et al.
"""

# Copyright (c) 2021 Intel Corporation
# \file distgnn/partition/libra_partition.py
# \brief Libra - Vertex-cut based graph partitioner for distributed training
# \author Vasimuddin Md ,
#         Guixiang Ma
#         Sanchit Misra ,
#         Ramanarayan Mohanty ,
#         Sasikanth Avancha
#         Nesreen K. Ahmed
# \cite Distributed Power-law Graph Computing: Theoretical and Empirical Analysis

import os
import time
import json
import torch as th
from dgl import DGLGraph
from dgl.sparse import libra_vertex_cut
from dgl.sparse import libra2dgl_build_dict
from dgl.sparse import libra2dgl_set_lr
from dgl.sparse import libra2dgl_build_adjlist
from dgl.data.utils import save_graphs, save_tensors
from dgl.base import DGLError


def libra_partition(num_community, G, resultdir):
    """
    Performs vertex-cut based graph partitioning and converts the partitioning
    output to DGL input format.

    Parameters
    ----------
    num_community : int
        Number of partitions to create.
    G : DGLGraph
        Input graph to be partitioned.
    resultdir : str
        Output location for storing the partitioned graphs.

    Output
    ------
    1. Creates a folder named XCommunities for X partitions (e.g., 2Communities for X=2).
       It contains one file communityZ.txt per partition Z (Z in 0..X-1);
       each such file lists the edges assigned to that partition.
       These files constitute the output of the Libra graph partitioner
       (an intermediate result of this function).
    2. The folder also contains partZ folders; each of them stores the
       DGL/DistGNN graph for partition Z.
       These graph files are used as input to DistGNN.
    3. The folder also contains a JSON file with the partitions' metadata.
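    An illustrative on-disk layout (assuming a pubmed input and num_community=2;
    the file names follow the code below):

        Libra_result_pubmed/2Communities/
            community0.txt
            community1.txt
            replicationlist.csv
            part0/
                graph.dgl
                node_feat.dgl
            part1/
                graph.dgl
                node_feat.dgl
            pubmed.json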
+ """ + + num_nodes = G.number_of_nodes() # number of nodes + num_edges = G.number_of_edges() # number of edges + print("Number of nodes in the graph: ", num_nodes) + print("Number of edges in the graph: ", num_edges) + + in_d = G.in_degrees() + out_d = G.out_degrees() + node_degree = in_d + out_d + edgenum_unassigned = node_degree.clone() + + u_t, v_t = G.edges() + weight_ = th.ones(u_t.shape[0], dtype=th.int64) + community_weights = th.zeros(num_community, dtype=th.int64) + + # self_loop = 0 + # for p, q in zip(u_t, v_t): + # if p == q: + # self_loop += 1 + # print("#self loops in the dataset: ", self_loop) + + # del G + + ## call to C/C++ code + out = th.zeros(u_t.shape[0], dtype=th.int32) + libra_vertex_cut(num_community, node_degree, edgenum_unassigned, community_weights, + u_t, v_t, weight_, out, num_nodes, num_edges, resultdir) + + print("Max partition size: ", int(community_weights.max())) + print(" ** Converting libra partitions to dgl graphs **") + fsize = int(community_weights.max()) + 1024 ## max edges in partition + # print("fsize: ", fsize, flush=True) + + node_map = th.zeros(num_community, dtype=th.int64) + indices = th.zeros(num_nodes, dtype=th.int64) + lrtensor = th.zeros(num_nodes, dtype=th.int64) + gdt_key = th.zeros(num_nodes, dtype=th.int64) + gdt_value = th.zeros([num_nodes, num_community], dtype=th.int64) + offset = th.zeros(1, dtype=th.int64) + ldt_ar = [] + + gg_ar = [DGLGraph() for i in range(num_community)] + part_nodes = [] + + print(">>> ", "num_nodes ", " ", "num_edges") + ## Iterator over number of partitions + for i in range(num_community): + g = gg_ar[i] + + a_t = th.zeros(fsize, dtype=th.int64) + b_t = th.zeros(fsize, dtype=th.int64) + ldt_key = th.zeros(fsize, dtype=th.int64) + ldt_ar.append(ldt_key) + + ## building node, parition dictionary + ## Assign local node ids and mapping to global node ids + ret = libra2dgl_build_dict(a_t, b_t, indices, ldt_key, gdt_key, gdt_value, + node_map, offset, num_community, i, fsize, resultdir) + + num_nodes_partition = int(ret[0]) + num_edges_partition = int(ret[1]) + part_nodes.append(num_nodes_partition) + print(">>> ", num_nodes_partition, " ", num_edges_partition) + g.add_edges(a_t[0:num_edges_partition], b_t[0:num_edges_partition]) + + ######################################################## + ## fixing lr - 1-level tree for the split-nodes + libra2dgl_set_lr(gdt_key, gdt_value, lrtensor, num_community, num_nodes) + ######################################################## + #graph_name = dataset + graph_name = resultdir.split("_")[-1].split("/")[0] + part_method = 'Libra' + num_parts = num_community ## number of paritions/communities + num_hops = 0 + node_map_val = node_map.tolist() + edge_map_val = 0 + out_path = resultdir + + part_metadata = {'graph_name': graph_name, + 'num_nodes': G.number_of_nodes(), + 'num_edges': G.number_of_edges(), + 'part_method': part_method, + 'num_parts': num_parts, + 'halo_hops': num_hops, + 'node_map': node_map_val, + 'edge_map': edge_map_val} + ############################################################ + + for i in range(num_community): + g = gg_ar[0] + num_nodes_partition = part_nodes[i] + adj = th.zeros([num_nodes_partition, num_community - 1], dtype=th.int64) + inner_node = th.zeros(num_nodes_partition, dtype=th.int32) + lr_t = th.zeros(num_nodes_partition, dtype=th.int64) + ldt = ldt_ar[0] + + try: + feat = G.ndata['feat'] + except KeyError: + feat = G.ndata['features'] + + try: + labels = G.ndata['label'] + except KeyError: + labels = G.ndata['labels'] + + trainm = 
        testm = G.ndata['test_mask'].int()
        valm = G.ndata['val_mask'].int()

        feat_size = feat.shape[1]
        gfeat = th.zeros([num_nodes_partition, feat_size], dtype=feat.dtype)

        glabels = th.zeros(num_nodes_partition, dtype=labels.dtype)
        gtrainm = th.zeros(num_nodes_partition, dtype=trainm.dtype)
        gtestm = th.zeros(num_nodes_partition, dtype=testm.dtype)
        gvalm = th.zeros(num_nodes_partition, dtype=valm.dtype)

        ## build the remote node database per local node
        ## gather feats, train, test, val, and labels for each partition
        libra2dgl_build_adjlist(feat, gfeat, adj, inner_node, ldt, gdt_key,
                                gdt_value, node_map, lr_t, lrtensor, num_nodes_partition,
                                num_community, i, feat_size, labels, trainm, testm, valm,
                                glabels, gtrainm, gtestm, gvalm, feat.shape[0])


        g.ndata['adj'] = adj                 ## database of remote clones
        g.ndata['inner_node'] = inner_node   ## '0' if the node is split across partitions, else '1'
        g.ndata['feat'] = gfeat              ## gathered features
        g.ndata['lf'] = lr_t                 ## 1-level tree among split nodes

        g.ndata['label'] = glabels
        g.ndata['train_mask'] = gtrainm
        g.ndata['test_mask'] = gtestm
        g.ndata['val_mask'] = gvalm

        # Validation code, run only on small graphs
        # for l in range(num_nodes_partition):
        #     index = int(ldt[l])
        #     assert glabels[l] == labels[index]
        #     assert gtrainm[l] == trainm[index]
        #     assert gtestm[l] == testm[index]
        #     for j in range(feat_size):
        #         assert gfeat[l][j] == feat[index][j]

        print("Writing partition {} to file".format(i), flush=True)

        part = g
        part_id = i
        part_dir = os.path.join(out_path, "part" + str(part_id))
        node_feat_file = os.path.join(part_dir, "node_feat.dgl")
        edge_feat_file = os.path.join(part_dir, "edge_feat.dgl")
        part_graph_file = os.path.join(part_dir, "graph.dgl")
        part_metadata['part-{}'.format(part_id)] = {'node_feats': node_feat_file,
                                                    'edge_feats': edge_feat_file,
                                                    'part_graph': part_graph_file}
        os.makedirs(part_dir, mode=0o775, exist_ok=True)
        save_tensors(node_feat_file, part.ndata)
        save_graphs(part_graph_file, [part])

        del g
        del gg_ar[0]
        del ldt
        del ldt_ar[0]

    with open('{}/{}.json'.format(out_path, graph_name), 'w') as outfile:
        json.dump(part_metadata, outfile, sort_keys=True, indent=4)

    print("Conversion libra2dgl completed !!!")


def partition_graph(num_community, G, resultdir):
    """
    Performs vertex-cut based graph partitioning and converts the partitioning
    output to DGL input format.

    Given a graph, this function will create a folder named ``XCommunities`` where ``X``
    stands for the number of communities. It will contain ``X`` files named
    ``communityZ.txt``, one for each partition Z (from 0 to X-1);
    each such file contains a list of edges assigned to that partition.
    These files constitute the output of the Libra graph partitioner.

    The folder also contains X subfolders named ``partZ``; each of these folders stores
    the DGL/DistGNN graph for partition Z. These graph files are used as input to
    DistGNN.

    The folder also contains a JSON file with the partitions' metadata.

    Currently we require the graph's node data to contain the following columns:

    * ``features`` for node features.
    * ``label`` for node labels.
    * ``train_mask`` as a boolean mask of the training node set.
    * ``val_mask`` as a boolean mask of the validation node set.
    * ``test_mask`` as a boolean mask of the test node set.

    Parameters
    ----------
    num_community : int
        Number of partitions to create.
    G : DGLGraph
        Input graph to be partitioned.
    resultdir : str
        Output location for storing the partitioned graphs.
    """

    print("num partitions: ", num_community)
    print("output location: ", resultdir)

    ## create output directory
    try:
        os.makedirs(resultdir, mode=0o775, exist_ok=True)
    except Exception:
        raise DGLError("Error: Could not create directory: {}".format(resultdir))

    tic = time.time()
    print("####################################################################")
    print("Executing partitions: ", num_community)
    ltic = time.time()
    try:
        resultdir = os.path.join(resultdir, str(num_community) + "Communities")
        os.makedirs(resultdir, mode=0o775, exist_ok=True)
    except Exception:
        raise DGLError("Error: Could not create sub-directory: {}".format(resultdir))

    ## Libra partitioning
    libra_partition(num_community, G, resultdir)

    ltoc = time.time()
    print("Time taken by {} partitions {:0.4f} sec".format(num_community, ltoc - ltic))
    print()

    toc = time.time()
    print("Generated ", num_community, " partitions in {:0.4f} sec".format(toc - tic), flush=True)
    print("Partitioning completed successfully !!!")
diff --git a/python/dgl/distgnn/tools/__init__.py b/python/dgl/distgnn/tools/__init__.py
new file mode 100644
index 000000000000..25bdf4c65809
--- /dev/null
+++ b/python/dgl/distgnn/tools/__init__.py
@@ -0,0 +1,4 @@
"""
This package contains extra routines related to the Libra graph partitioner.
"""
from .tools import load_proteins
diff --git a/python/dgl/distgnn/tools/tools.py b/python/dgl/distgnn/tools/tools.py
new file mode 100644
index 000000000000..610315257304
--- /dev/null
+++ b/python/dgl/distgnn/tools/tools.py
@@ -0,0 +1,155 @@
r"""
Copyright (c) 2021 Intel Corporation
 \file distgnn/tools/tools.py
 \brief Tools for use in the Libra graph partitioner.
 \author Vasimuddin Md
"""

import os
import requests
from scipy.io import mmread
import torch as th
import dgl
from dgl.base import DGLError
from dgl.data.utils import load_graphs, save_graphs, save_tensors

def rep_per_node(prefix, num_community):
    """
    Used on Libra partitioned data.
    This function reports the number of split copies per node (replication) of
    a partitioned graph.
    Parameters
    ----------
    prefix : Partition folder location (contains replicationlist.csv)
    num_community : number of partitions or communities
    """
    ifile = os.path.join(prefix, 'replicationlist.csv')
    fhandle = open(ifile, "r")
    r_dt = {}

    fline = fhandle.readline()   ## read the first line, which contains the header comment
    print(fline)
    for line in fhandle:
        if line[0] == '#':
            raise DGLError("[Bug] Read Hash char in rep_per_node func.")

        node = line.strip('\n')
        if r_dt.get(node, -100) == -100:
            r_dt[node] = 1
        else:
            r_dt[node] += 1

    fhandle.close()
    ## sanity checks
    for v in r_dt.values():
        if v >= num_community:
            raise DGLError("[Bug] Unexpected event in rep_per_node() in tools.py.")

    return r_dt


def download_proteins():
    """
    Downloads the Proteins dataset
    """
    print("Downloading dataset...")
    print("This might take a while..")
    url = "https://portal.nersc.gov/project/m1982/GNN/"
    file_name = "subgraph3_iso_vs_iso_30_70length_ALL.m100.propermm.mtx"
    url = url + file_name
    try:
        req = requests.get(url)
    except Exception:
        raise DGLError("Error: Failed to download Proteins dataset!! Aborting..")

    with open("proteins.mtx", "wb") as handle:
        handle.write(req.content)


def proteins_mtx2dgl():
    """
    This function converts the Proteins dataset from mtx to DGL format.
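    The node features, labels, and train/test/val masks attached here are
    synthetic placeholders (their sizes are picked arbitrarily, as noted below),
    since the raw Proteins matrix carries no node attributes.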
+ """ + print("Converting mtx2dgl..") + print("This might a take while..") + a_mtx = mmread('proteins.mtx') + coo = a_mtx.tocoo() + u = th.tensor(coo.row, dtype=th.int64) + v = th.tensor(coo.col, dtype=th.int64) + g = dgl.DGLGraph() + + g.add_edges(u, v) + + n = g.number_of_nodes() + feat_size = 128 ## arbitrary number + feats = th.empty([n, feat_size], dtype=th.float32) + + ## arbitrary numbers + train_size = 1000000 + test_size = 500000 + val_size = 5000 + nlabels = 256 + + train_mask = th.zeros(n, dtype=th.bool) + test_mask = th.zeros(n, dtype=th.bool) + val_mask = th.zeros(n, dtype=th.bool) + label = th.zeros(n, dtype=th.int64) + + for i in range(train_size): + train_mask[i] = True + + for i in range(test_size): + test_mask[train_size + i] = True + + for i in range(val_size): + val_mask[train_size + test_size + i] = True + + for i in range(n): + label[i] = random.choice(range(nlabels)) + + g.ndata['feat'] = feats + g.ndata['train_mask'] = train_mask + g.ndata['test_mask'] = test_mask + g.ndata['val_mask'] = val_mask + g.ndata['label'] = label + + return g + + +def save(g, dataset): + """ + This function saves input dataset to dgl format + Parameters + ---------- + g : graph to be saved + dataset : output folder name + """ + print("Saving dataset..") + part_dir = os.path.join("./" + dataset) + node_feat_file = os.path.join(part_dir, "node_feat.dgl") + part_graph_file = os.path.join(part_dir, "graph.dgl") + os.makedirs(part_dir, mode=0o775, exist_ok=True) + save_tensors(node_feat_file, g.ndata) + save_graphs(part_graph_file, [g]) + print("Graph saved successfully !!") + + +def load_proteins(dataset): + """ + This function downloads, converts, and load Proteins graph dataset + Parameter + --------- + dataset: output folder name + """ + part_dir = dataset + graph_file = os.path.join(part_dir + "/graph.dgl") + + if not os.path.exists("proteins.mtx"): + download_proteins() + if not os.path.exists(graph_file): + g = proteins_mtx2dgl() + save(g, dataset) + ## load + graph = load_graphs(graph_file)[0][0] + return graph diff --git a/python/dgl/sparse.py b/python/dgl/sparse.py index 140f1d8af9fa..244031e680e3 100644 --- a/python/dgl/sparse.py +++ b/python/dgl/sparse.py @@ -703,4 +703,104 @@ def _csrmask(A, A_weights, B): """ return F.from_dgl_nd(_CAPI_DGLCSRMask(A, F.to_dgl_nd(A_weights), B)) + + +################################################################################################### +## Libra Graph Partition +def libra_vertex_cut(nc, node_degree, edgenum_unassigned, + community_weights, u, v, w, out, N, N_e, dataset): + """ + This function invokes C/C++ code for Libra based graph partitioning. + Parameter details are present in dgl/src/array/libra_partition.cc + """ + _CAPI_DGLLibraVertexCut(nc, + to_dgl_nd_for_write(node_degree), + to_dgl_nd_for_write(edgenum_unassigned), + to_dgl_nd_for_write(community_weights), + to_dgl_nd(u), + to_dgl_nd(v), + to_dgl_nd(w), + to_dgl_nd_for_write(out), + N, + N_e, + dataset) + + +def libra2dgl_build_dict(a, b, indices, ldt_key, gdt_key, gdt_value, node_map, + offset, nc, c, fsize, dataset): + """ + This function invokes C/C++ code for pre-processing Libra output. + After graph partitioning using Libra, during conversion from Libra output to DGL/DistGNN input, + this function creates dictionaries to assign local node ids to the partitioned nodes + and also to create a database of the split nodes. 
    Parameter details are present in dgl/src/array/libra_partition.cc
    """
    ret = _CAPI_DGLLibra2dglBuildDict(to_dgl_nd_for_write(a),
                                      to_dgl_nd_for_write(b),
                                      to_dgl_nd_for_write(indices),
                                      to_dgl_nd_for_write(ldt_key),
                                      to_dgl_nd_for_write(gdt_key),
                                      to_dgl_nd_for_write(gdt_value),
                                      to_dgl_nd_for_write(node_map),
                                      to_dgl_nd_for_write(offset),
                                      nc,
                                      c,
                                      fsize,
                                      dataset)
    return ret


def libra2dgl_build_adjlist(feat, gfeat, adj, inner_node, ldt, gdt_key,
                            gdt_value, node_map, lr, lrtensor, num_nodes,
                            nc, c, feat_size, labels, trainm, testm, valm,
                            glabels, gtrainm, gtestm, gvalm, feat_shape):
    """
    This function invokes C/C++ code for pre-processing Libra output.
    After graph partitioning using Libra, once the local and global dictionaries are built,
    for each node in each partition, this function copies the split node details from the
    global dictionary. It also copies features, label, train, test, and validation information
    for each node from the input graph to the corresponding partitions.
    Parameter details are present in dgl/src/array/libra_partition.cc
    """
    _CAPI_DGLLibra2dglBuildAdjlist(to_dgl_nd(feat),
                                   to_dgl_nd_for_write(gfeat),
                                   to_dgl_nd_for_write(adj),
                                   to_dgl_nd_for_write(inner_node),
                                   to_dgl_nd(ldt),
                                   to_dgl_nd(gdt_key),
                                   to_dgl_nd(gdt_value),
                                   to_dgl_nd(node_map),
                                   to_dgl_nd_for_write(lr),
                                   to_dgl_nd(lrtensor),
                                   num_nodes,
                                   nc,
                                   c,
                                   feat_size,
                                   to_dgl_nd(labels),
                                   to_dgl_nd(trainm),
                                   to_dgl_nd(testm),
                                   to_dgl_nd(valm),
                                   to_dgl_nd_for_write(glabels),
                                   to_dgl_nd_for_write(gtrainm),
                                   to_dgl_nd_for_write(gtestm),
                                   to_dgl_nd_for_write(gvalm),
                                   feat_shape)


def libra2dgl_set_lr(gdt_key, gdt_value, lrtensor, nc, Nn):
    """
    This function invokes C/C++ code for pre-processing Libra output.
    To prepare the graph partitions for DistGNN input, this function sets the leaf
    and root (1-level tree) among the split copies (across different partitions)
    of a node from the input graph.
    Parameter details are present in dgl/src/array/libra_partition.cc
    """
    _CAPI_DGLLibra2dglSetLR(to_dgl_nd(gdt_key),
                            to_dgl_nd(gdt_value),
                            to_dgl_nd_for_write(lrtensor),
                            nc,
                            Nn)

 _init_api("dgl.sparse")
diff --git a/src/array/kernel.cc b/src/array/kernel.cc
index 37d0f154dd69..7e7a749cd41a 100644
--- a/src/array/kernel.cc
+++ b/src/array/kernel.cc
@@ -604,6 +604,5 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_FG_SDDMMTreeReduction")
   });
 #endif  // USE_TVM

-
 }  // namespace aten
 }  // namespace dgl
diff --git a/src/array/libra_partition.cc b/src/array/libra_partition.cc
new file mode 100644
index 000000000000..c19babb531fa
--- /dev/null
+++ b/src/array/libra_partition.cc
@@ -0,0 +1,690 @@
/*
Copyright (c) 2021 Intel Corporation
 \file array/libra_partition.cc
 \brief Libra - Vertex-cut based graph partitioner for distributed training
 \author Vasimuddin Md ,
         Guixiang Ma
         Sanchit Misra ,
         Ramanarayan Mohanty ,
         Sasikanth Avancha
         Nesreen K. Ahmed
*/

// NOTE: the angle-bracket header names below were lost in transit and are
// reconstructed from the symbols this file uses (NDArray, List/Value,
// RandomEngine, runtime::parallel_for); treat them as a best guess.
#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/random.h>
#include <dgl/runtime/parallel_for.h>
#include <stdio.h>
#include <vector>

#ifdef USE_TVM
#include <featgraph.h>
#endif  // USE_TVM

#include "kernel_decl.h"
#include "../c_api_common.h"
#include "./check.h"

using namespace dgl::runtime;

namespace dgl {
namespace aten {

template <typename IdType>
int32_t Ver2partition(IdType in_val, int64_t* node_map, int32_t num_parts) {
  int32_t pos = 0;
  for (int32_t p=0; p < num_parts; p++) {
    if (in_val < node_map[p])
      return pos;
    pos = pos + 1;
  }
  LOG(FATAL) << "Error: Unexpected output in Ver2partition!";
}

/*!
 \brief Identifies the least loaded partition/community for a given edge assignment. */
int32_t LeastLoad(int64_t* community_edges, int32_t nc) {
  std::vector<int32_t> loc;
  int64_t min = 1e9;
  for (int32_t i=0; i < nc; i++) {
    if (community_edges[i] < min) {
      min = community_edges[i];
    }
  }
  for (int32_t i=0; i < nc; i++) {
    if (community_edges[i] == min) {
      loc.push_back(i);
    }
  }

  int32_t r = RandomEngine::ThreadLocal()->RandInt(loc.size());
  CHECK(loc[r] < nc);
  return loc[r];
}

/*! \brief Libra - vertex-cut based graph partitioning.
  It takes the list of edges from the input DGL graph and distributes them among nc partitions.
  During edge distribution, Libra assigns a given edge to a partition based on its end vertices;
  in doing so, it tries to minimize the splitting of the graph vertices. In case of a conflict,
  Libra assigns the edge to the least loaded partition/community.
  \param[in] nc Number of partitions/communities
  \param[in] node_degree per node degree
  \param[in] edgenum_unassigned per node count of the edges not yet assigned to a partition
  \param[out] community_weights weight of the created partitions
  \param[in] u src nodes
  \param[in] v dst nodes
  \param[in] w weight per edge
  \param[out] out partition assignment of the edges
  \param[in] N_n number of nodes in the input graph
  \param[in] N_e number of edges in the input graph
  \param[in] prefix output/partition storage location
*/
template <typename IdType, typename IdType2>
void LibraVertexCut(
    int32_t nc,
    NDArray node_degree,
    NDArray edgenum_unassigned,
    NDArray community_weights,
    NDArray u,
    NDArray v,
    NDArray w,
    NDArray out,
    int64_t N_n,
    int64_t N_e,
    const std::string& prefix) {
  int32_t *out_ptr = out.Ptr<int32_t>();
  IdType2 *node_degree_ptr = node_degree.Ptr<IdType2>();
  IdType2 *edgenum_unassigned_ptr = edgenum_unassigned.Ptr<IdType2>();
  IdType *u_ptr = u.Ptr<IdType>();
  IdType *v_ptr = v.Ptr<IdType>();
  int64_t *w_ptr = w.Ptr<int64_t>();
  int64_t *community_weights_ptr = community_weights.Ptr<int64_t>();

  std::vector<std::vector<int32_t> > node_assignments(N_n);
  std::vector<IdType> replication_list;
  // local allocations
  int64_t *community_edges = new int64_t[nc]();
  int64_t *cache = new int64_t[nc]();

  int64_t meter = static_cast<int64_t>(N_e/100);
  for (int64_t i=0; i < N_e; i++) {
    IdType u = u_ptr[i];    // edge end vertex 1
    IdType v = v_ptr[i];    // edge end vertex 2
    int64_t w = w_ptr[i];   // edge weight

    CHECK(u < N_n);
    CHECK(v < N_n);

    if (meter > 0 && i % meter == 0) {   // progress dots; guard avoids division by zero on tiny graphs
      fprintf(stderr, "."); fflush(0);
    }

    if (node_assignments[u].size() == 0 && node_assignments[v].size() == 0) {
      int32_t c = LeastLoad(community_edges, nc);
      out_ptr[i] = c;
      CHECK_LT(c, nc);

      community_edges[c]++;
      community_weights_ptr[c] = community_weights_ptr[c] + w;
      node_assignments[u].push_back(c);
      if (u != v)
        node_assignments[v].push_back(c);

      CHECK(node_assignments[u].size() <= nc) <<
        "[bug] 1. generated splits (u) are greater than nc!";
      CHECK(node_assignments[v].size() <= nc) <<
        "[bug] 1. generated splits (v) are greater than nc!";
      edgenum_unassigned_ptr[u]--;
      edgenum_unassigned_ptr[v]--;
    } else if (node_assignments[u].size() != 0 && node_assignments[v].size() == 0) {
      for (uint32_t j=0; j < node_assignments[u].size(); j++) {
        int32_t cind = node_assignments[u][j];
        cache[j] = community_edges[cind];
      }
      int32_t cindex = LeastLoad(cache, node_assignments[u].size());
      int32_t c = node_assignments[u][cindex];
      out_ptr[i] = c;
      community_edges[c]++;
      community_weights_ptr[c] = community_weights_ptr[c] + w;

      node_assignments[v].push_back(c);
      CHECK(node_assignments[v].size() <= nc) <<
        "[bug] 2. generated splits (v) are greater than nc!";
      edgenum_unassigned_ptr[u]--;
      edgenum_unassigned_ptr[v]--;
    } else if (node_assignments[v].size() != 0 && node_assignments[u].size() == 0) {
      for (uint32_t j=0; j < node_assignments[v].size(); j++) {
        int32_t cind = node_assignments[v][j];
        cache[j] = community_edges[cind];
      }
      int32_t cindex = LeastLoad(cache, node_assignments[v].size());
      int32_t c = node_assignments[v][cindex];
      CHECK(c < nc) << "[bug] 2. partition greater than nc !!";
      out_ptr[i] = c;

      community_edges[c]++;
      community_weights_ptr[c] = community_weights_ptr[c] + w;

      node_assignments[u].push_back(c);
      CHECK(node_assignments[u].size() <= nc) <<
        "[bug] 3. generated splits (u) are greater than nc!";
      edgenum_unassigned_ptr[u]--;
      edgenum_unassigned_ptr[v]--;
    } else {
      std::vector<int32_t> setv(nc), intersetv;
      for (int32_t j=0; j < nc; j++) setv[j] = 0;
      int32_t interset = 0;

      CHECK(node_assignments[u].size() <= nc) <<
        "[bug] 4. generated splits (u) are greater than nc!";
      CHECK(node_assignments[v].size() <= nc) <<
        "[bug] 4. generated splits (v) are greater than nc!";
      for (uint32_t j=0; j < node_assignments[v].size(); j++) {
        CHECK(node_assignments[v][j] < nc) << "[bug] 4. Part assigned (v) greater than nc!";
        setv[node_assignments[v][j]]++;
      }

      for (uint32_t j=0; j < node_assignments[u].size(); j++) {
        CHECK(node_assignments[u][j] < nc) << "[bug] 4. Part assigned (u) greater than nc!";
        setv[node_assignments[u][j]]++;
      }

      for (int32_t j=0; j < nc; j++) {
        CHECK(setv[j] <= 2) << "[bug] 4. unexpected computed value !!!";
        if (setv[j] == 2) {
          interset++;
          intersetv.push_back(j);
        }
      }
      if (interset) {
        for (uint32_t j=0; j < intersetv.size(); j++) {
          int32_t cind = intersetv[j];
          cache[j] = community_edges[cind];
        }
        int32_t cindex = LeastLoad(cache, intersetv.size());
        int32_t c = intersetv[cindex];
        CHECK(c < nc) << "[bug] 4. partition greater than nc !!";
        out_ptr[i] = c;
        community_edges[c]++;
        community_weights_ptr[c] = community_weights_ptr[c] + w;
        edgenum_unassigned_ptr[u]--;
        edgenum_unassigned_ptr[v]--;
      } else {
        if (node_degree_ptr[u] < node_degree_ptr[v]) {
          for (uint32_t j=0; j < node_assignments[u].size(); j++) {
            int32_t cind = node_assignments[u][j];
            cache[j] = community_edges[cind];
          }
          int32_t cindex = LeastLoad(cache, node_assignments[u].size());
          int32_t c = node_assignments[u][cindex];
          CHECK(c < nc) << "[bug] 5. partition greater than nc !!";
          out_ptr[i] = c;
          community_edges[c]++;
          community_weights_ptr[c] = community_weights_ptr[c] + w;

          for (uint32_t j=0; j < node_assignments[v].size(); j++) {
            CHECK(node_assignments[v][j] != c) <<
              "[bug] 5. duplicate partition (v) assignment !!";
          }

          node_assignments[v].push_back(c);
          CHECK(node_assignments[v].size() <= nc) <<
            "[bug] 5. generated splits (v) greater than nc!!";
          replication_list.push_back(v);
          edgenum_unassigned_ptr[u]--;
          edgenum_unassigned_ptr[v]--;
        } else {
          for (uint32_t j=0; j < node_assignments[v].size(); j++) {
            int32_t cind = node_assignments[v][j];
            cache[j] = community_edges[cind];
          }
          int32_t cindex = LeastLoad(cache, node_assignments[v].size());
          int32_t c = node_assignments[v][cindex];
          CHECK(c < nc) << "[bug] 6. partition greater than nc !!";
          out_ptr[i] = c;
          community_edges[c]++;
          community_weights_ptr[c] = community_weights_ptr[c] + w;
          for (uint32_t j=0; j < node_assignments[u].size(); j++) {
            CHECK(node_assignments[u][j] != c) <<
              "[bug] 6. duplicate partition (u) assignment !!";
          }
          if (u != v)
            node_assignments[u].push_back(c);

          CHECK(node_assignments[u].size() <= nc) <<
            "[bug] 6. generated splits (u) greater than nc!!";
          replication_list.push_back(u);
          edgenum_unassigned_ptr[u]--;
          edgenum_unassigned_ptr[v]--;
        }
      }
    }
  }
  delete[] cache;   // paired delete[] for array new

  for (int64_t c=0; c < nc; c++) {
    std::string path = prefix + "/community" + std::to_string(c) +".txt";

    FILE *fp = fopen(path.c_str(), "w");
    CHECK_NE(fp, static_cast<FILE *>(NULL)) << "Error: can not open file: " << path.c_str();

    for (int64_t i=0; i < N_e; i++) {
      if (out_ptr[i] == c)
        fprintf(fp, "%ld,%ld,%f\n",
                static_cast<int64_t>(u_ptr[i]),
                static_cast<int64_t>(v_ptr[i]),
                static_cast<float>(w_ptr[i]));   // weight written as float to match the %f reader
    }
    fclose(fp);
  }

  std::string path = prefix + "/replicationlist.csv";
  FILE *fp = fopen(path.c_str(), "w");
  CHECK_NE(fp, static_cast<FILE *>(NULL)) << "Error: can not open file: " << path.c_str();

  fprintf(fp, "## The Indices of Nodes that are replicated :: Header\n");
  printf("\nTotal replication: %zu\n", replication_list.size());

  for (uint64_t i=0; i < replication_list.size(); i++)
    fprintf(fp, "%ld\n", static_cast<int64_t>(replication_list[i]));

  printf("Community weights:\n");
  for (int64_t c=0; c < nc; c++)
    printf("%ld ", community_weights_ptr[c]);
  printf("\n");

  printf("Community edges:\n");
  for (int64_t c=0; c < nc; c++)
    printf("%ld ", community_edges[c]);
  printf("\n");

  delete[] community_edges;
  fclose(fp);
}

DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibraVertexCut")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  int32_t nc = args[0];
  NDArray node_degree = args[1];
  NDArray edgenum_unassigned = args[2];
  NDArray community_weights = args[3];
  NDArray u = args[4];
  NDArray v = args[5];
  NDArray w = args[6];
  NDArray out = args[7];
  int64_t N = args[8];
  int64_t N_e = args[9];
  std::string prefix = args[10];

  ATEN_ID_TYPE_SWITCH(node_degree->dtype, IdType2, {
    ATEN_ID_TYPE_SWITCH(u->dtype, IdType, {
      LibraVertexCut<IdType, IdType2>(nc,
                                      node_degree,
                                      edgenum_unassigned,
                                      community_weights,
                                      u,
                                      v,
                                      w,
                                      out,
                                      N,
                                      N_e,
                                      prefix);
    });
  });
});


/*! \brief
  1. Builds a dictionary (ldt) for assigning local node IDs to the nodes in each partition.
  2. Builds a dictionary (gdt) for storing the copies (local IDs) of the split nodes.
  These dictionaries are used in the subsequent stages to track the split-node copies
  across the partitions and to set up the partition `ndata` dictionaries.
  \param[out] a local src node ID of an edge in a partition
  \param[out] b local dst node ID of an edge in a partition
  \param[-] indices temporary memory, keeps track of global node ID to local node ID in a partition
  \param[out] ldt_key per partition dict for storing global and local node IDs (consecutive)
  \param[out] gdt_key global dict for storing the number of local nodes (or split nodes) for a
              given global node ID
  \param[out] gdt_value global dict, stores the local node IDs (due to splits) across partitions
              for a given global node ID
  \param[out] node_map keeps track of the ranges of local node IDs (consecutive) given to the
              nodes in the partitions
  \param[in, out] offset start of the range of local node IDs for this partition
  \param[in] nc number of partitions/communities
  \param[in] c current partition number
  \param[in] fsize size of pre-allocated memory tensor
  \param[in] prefix input Libra partition file location
 */
List<Value> Libra2dglBuildDict(
    NDArray a,
    NDArray b,
    NDArray indices,
    NDArray ldt_key,
    NDArray gdt_key,
    NDArray gdt_value,
    NDArray node_map,
    NDArray offset,
    int32_t nc,
    int32_t c,
    int64_t fsize,
    const std::string& prefix) {
  int64_t *indices_ptr = indices.Ptr<int64_t>();      // 1D temp array
  int64_t *ldt_key_ptr = ldt_key.Ptr<int64_t>();      // 1D local nodes <-> global nodes
  int64_t *gdt_key_ptr = gdt_key.Ptr<int64_t>();      // 1D #split copies per node
  int64_t *gdt_value_ptr = gdt_value.Ptr<int64_t>();  // 2D tensor
  int64_t *node_map_ptr = node_map.Ptr<int64_t>();    // 1D tensor
  int64_t *offset_ptr = offset.Ptr<int64_t>();        // 1D tensor
  int32_t width = nc;

  int64_t *a_ptr = a.Ptr<int64_t>();   // stores local src and dst node IDs,
  int64_t *b_ptr = b.Ptr<int64_t>();   // used to create the partition graph

  int64_t N_n = indices->shape[0];
  int64_t num_nodes = ldt_key->shape[0];

  for (int64_t i=0; i < N_n; i++) {
    indices_ptr[i] = -100;
  }

  int64_t pos = 0;
  int64_t edge = 0;
  std::string path = prefix + "/community" + std::to_string(c) + ".txt";
  FILE *fp = fopen(path.c_str(), "r");
  CHECK_NE(fp, static_cast<FILE *>(NULL)) << "Error: can not open file: " << path.c_str();

  while (edge < fsize) {
    int64_t u, v;
    float w;
    // read one edge: the src and dst global node IDs (stop at EOF or a malformed line)
    if (fscanf(fp, "%ld,%ld,%f\n", &u, &v, &w) != 3)
      break;

    if (indices_ptr[u] == -100) {   // this global node ID has no local node ID yet
      ldt_key_ptr[pos] = u;         // remember the global ID at its local slot
      CHECK(pos < num_nodes);       // sanity check
      indices_ptr[u] = pos++;       // assign consecutive local node IDs
    }
    if (indices_ptr[v] == -100) {
      ldt_key_ptr[pos] = v;
      CHECK(pos < num_nodes);       // sanity check
      indices_ptr[v] = pos++;
    }
    a_ptr[edge] = indices_ptr[u];     // local ID of the edge's src
    b_ptr[edge++] = indices_ptr[v];   // local ID of the edge's dst
  }
  CHECK(edge <= fsize) << "[Bug] memory allocated for #edges per partition is not enough.";
  fclose(fp);

  List<Value> ret;
  ret.push_back(Value(MakeValue(pos)));    // total number of nodes in this partition
  ret.push_back(Value(MakeValue(edge)));   // total number of edges in this partition

  for (int64_t i=0; i < pos; i++) {
    int64_t u = ldt_key_ptr[i];        // global node ID
    // int64_t v = indices_ptr[u];
    int64_t v = i;                     // local node ID
    int64_t *ind = &gdt_key_ptr[u];    // number of local IDs (an offset) recorded so far
                                       // for this global node ID
    int64_t *ptr = gdt_value_ptr + u*width;
    ptr[*ind] = offset_ptr[0] + v;     // store this partition's local node ID for the global node
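    // (gdt_value is an N_n x nc table: row u lists the local node IDs of every
    // split copy of global node u, one entry per owning partition, and
    // gdt_key[u] counts the valid entries in that row.)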
    (*ind)++;
    CHECK_NE(v, -100);
    CHECK(*ind <= nc);
  }
  node_map_ptr[c] = offset_ptr[0] + pos;   // local node IDs of a partition are consecutive,
                                           // so one range per partition tracks them
  offset_ptr[0] += pos;

  return ret;
}


DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibra2dglBuildDict")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  NDArray a = args[0];
  NDArray b = args[1];
  NDArray indices = args[2];
  NDArray ldt_key = args[3];
  NDArray gdt_key = args[4];
  NDArray gdt_value = args[5];
  NDArray node_map = args[6];
  NDArray offset = args[7];
  int32_t nc = args[8];
  int32_t c = args[9];
  int64_t fsize = args[10];
  std::string prefix = args[11];
  List<Value> ret = Libra2dglBuildDict(a, b, indices, ldt_key, gdt_key,
                                       gdt_value, node_map, offset,
                                       nc, c, fsize, prefix);
  *rv = ret;
});


/*! \brief Sets up the 1-level tree among the clones of the split nodes.
  \param[in] gdt_key global dict storing the number of split copies per global node ID
  \param[in] gdt_value global dict storing the local node IDs of the split copies across
             the partitions for each global node ID
  \param[out] lrtensor keeps the root node ID of the 1-level tree
  \param[in] nc number of partitions/communities
  \param[in] Nn number of nodes in the input graph
 */
void Libra2dglSetLR(
    NDArray gdt_key,
    NDArray gdt_value,
    NDArray lrtensor,
    int32_t nc,
    int64_t Nn) {
  int64_t *gdt_key_ptr = gdt_key.Ptr<int64_t>();      // 1D tensor
  int64_t *gdt_value_ptr = gdt_value.Ptr<int64_t>();  // 2D tensor
  int64_t *lrtensor_ptr = lrtensor.Ptr<int64_t>();    // 1D tensor

  int32_t width = nc;
  int64_t cnt = 0;
  int64_t avg_split_copy = 0, scnt = 0;   // statistics only; currently unused

  for (int64_t i=0; i < Nn; i++) {
    if (gdt_key_ptr[i] <= 0) {
      cnt++;
    } else {
      // pick one copy at random as the root of the 1-level tree
      int32_t val = RandomEngine::ThreadLocal()->RandInt(gdt_key_ptr[i]);
      CHECK(val >= 0 && val < gdt_key_ptr[i]);
      CHECK(gdt_key_ptr[i] <= nc);

      int64_t *ptr = gdt_value_ptr + i*width;
      lrtensor_ptr[i] = ptr[val];
    }
    if (gdt_key_ptr[i] > 1) {
      avg_split_copy += gdt_key_ptr[i];
      scnt++;
    }
  }
}

DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibra2dglSetLR")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  NDArray gdt_key = args[0];
  NDArray gdt_value = args[1];
  NDArray lrtensor = args[2];
  int32_t nc = args[3];
  int64_t Nn = args[4];

  Libra2dglSetLR(gdt_key, gdt_value, lrtensor, nc, Nn);
});


/*!
  \brief For each node in a partition, it creates a list of remote clone IDs;
  also, for each node in a partition, it gathers the data (feats, label, train, test)
  from the input graph.
  \param[in] feat input graph node features
  \param[out] gfeat node features of current partition c
  \param[out] adj list of node IDs of remote clones
  \param[out] inner_node marks whether a node is split or not
  \param[in] ldt_key per partition dict for tracking global to local node IDs
  \param[in] gdt_key global dict storing the number of local nodes (or split nodes) for a
             given global node ID
  \param[in] gdt_value global dict, stores the local node IDs (due to splits) across partitions
             for a given global node ID
  \param[in] node_map keeps track of the ranges of local node IDs (consecutive) given to the
             nodes in the partitions
  \param[out] lr 1-level tree marking for local split nodes
  \param[in] lrtensor global (all the partitions) 1-level tree
  \param[in] num_nodes number of nodes in current partition
  \param[in] nc number of partitions/communities
  \param[in] c current partition/community
  \param[in] feat_size node feature vector size
  \param[in] labels global (input graph) labels
  \param[in] trainm global (input graph) training nodes
  \param[in] testm global (input graph) testing nodes
  \param[in] valm global (input graph) validation nodes
  \param[out] glabels local (for this partition) labels
  \param[out] gtrainm local (for this partition) training nodes
  \param[out] gtestm local (for this partition) testing nodes
  \param[out] gvalm local (for this partition) validation nodes
  \param[in] Nn number of nodes in the input graph
 */
template <typename IdType, typename IdType2, typename DType>
void Libra2dglBuildAdjlist(
    NDArray feat,
    NDArray gfeat,
    NDArray adj,
    NDArray inner_node,
    NDArray ldt_key,
    NDArray gdt_key,
    NDArray gdt_value,
    NDArray node_map,
    NDArray lr,
    NDArray lrtensor,
    int64_t num_nodes,
    int32_t nc,
    int32_t c,
    int32_t feat_size,
    NDArray labels,
    NDArray trainm,
    NDArray testm,
    NDArray valm,
    NDArray glabels,
    NDArray gtrainm,
    NDArray gtestm,
    NDArray gvalm,
    int64_t Nn) {
  DType *feat_ptr = feat.Ptr<DType>();    // 2D tensor
  DType *gfeat_ptr = gfeat.Ptr<DType>();  // 2D tensor
  int64_t *adj_ptr = adj.Ptr<int64_t>();  // 2D tensor
  int32_t *inner_node_ptr = inner_node.Ptr<int32_t>();
  int64_t *ldt_key_ptr = ldt_key.Ptr<int64_t>();
  int64_t *gdt_key_ptr = gdt_key.Ptr<int64_t>();
  int64_t *gdt_value_ptr = gdt_value.Ptr<int64_t>();  // 2D tensor
  int64_t *node_map_ptr = node_map.Ptr<int64_t>();
  int64_t *lr_ptr = lr.Ptr<int64_t>();
  int64_t *lrtensor_ptr = lrtensor.Ptr<int64_t>();
  int32_t width = nc - 1;

  runtime::parallel_for(0, num_nodes, [&] (int64_t s, int64_t e) {
    for (int64_t i=s; i < e; i++) {
      int64_t k = ldt_key_ptr[i];
      int64_t v = i;
      int64_t ind = gdt_key_ptr[k];

      int64_t *adj_ptr_ptr = adj_ptr + v*width;
      if (ind == 1) {   // not a split node: no remote clones
        for (int32_t j=0; j < width; j++) adj_ptr_ptr[j] = -1;
        inner_node_ptr[i] = 1;
        lr_ptr[i] = -200;
      } else {
        lr_ptr[i] = lrtensor_ptr[k];
        int64_t *ptr = gdt_value_ptr + k*nc;
        int64_t pos = 0;
        CHECK(ind <= nc);
        int32_t flg = 0;
        for (int64_t j=0; j < ind; j++) {
          if (ptr[j] == lr_ptr[i]) flg = 1;
          if (c != Ver2partition(ptr[j], node_map_ptr, nc))
            adj_ptr_ptr[pos++] = ptr[j];
        }
        CHECK_EQ(flg, 1);
        CHECK(pos == ind - 1);
        for (; pos < width; pos++) adj_ptr_ptr[pos] = -1;
        inner_node_ptr[i] = 0;
      }
    }
  });

  // gather node features in parallel
  runtime::parallel_for(0, num_nodes, [&] (int64_t s, int64_t e) {
    for (int64_t i=s; i < e; i++) {
      int64_t k = ldt_key_ptr[i];
      int64_t ind = i*feat_size;
      DType *optr = gfeat_ptr + ind;
      DType *iptr = feat_ptr + k*feat_size;

      for (int32_t j=0; j < feat_size; j++)
        optr[j] = iptr[j];
    }
  });

  // gather labels and masks sequentially; this loop covers the full node range,
  // so it is kept outside the parallel region to avoid redundant work per thread
  IdType *labels_ptr = labels.Ptr<IdType>();
  IdType *glabels_ptr = glabels.Ptr<IdType>();
  IdType2 *trainm_ptr = trainm.Ptr<IdType2>();
  IdType2 *gtrainm_ptr = gtrainm.Ptr<IdType2>();
  IdType2 *testm_ptr = testm.Ptr<IdType2>();
  IdType2 *gtestm_ptr = gtestm.Ptr<IdType2>();
  IdType2 *valm_ptr = valm.Ptr<IdType2>();
  IdType2 *gvalm_ptr = gvalm.Ptr<IdType2>();

  for (int64_t i=0; i < num_nodes; i++) {
    int64_t k = ldt_key_ptr[i];
    CHECK(k >= 0 && k < Nn);
    glabels_ptr[i] = labels_ptr[k];
    gtrainm_ptr[i] = trainm_ptr[k];
    gtestm_ptr[i] = testm_ptr[k];
    gvalm_ptr[i] = valm_ptr[k];
  }
}


DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibra2dglBuildAdjlist")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  NDArray feat = args[0];
  NDArray gfeat = args[1];
  NDArray adj = args[2];
  NDArray inner_node = args[3];
  NDArray ldt_key = args[4];
  NDArray gdt_key = args[5];
  NDArray gdt_value = args[6];
  NDArray node_map = args[7];
  NDArray lr = args[8];
  NDArray lrtensor = args[9];
  int64_t num_nodes = args[10];
  int32_t nc = args[11];
  int32_t c = args[12];
  int32_t feat_size = args[13];
  NDArray labels = args[14];
  NDArray trainm = args[15];
  NDArray testm = args[16];
  NDArray valm = args[17];
  NDArray glabels = args[18];
  NDArray gtrainm = args[19];
  NDArray gtestm = args[20];
  NDArray gvalm = args[21];
  int64_t Nn = args[22];

  ATEN_FLOAT_TYPE_SWITCH(feat->dtype, DType, "Features", {
    ATEN_ID_TYPE_SWITCH(trainm->dtype, IdType2, {
      ATEN_ID_BITS_SWITCH((glabels->dtype).bits, IdType, {
        Libra2dglBuildAdjlist<IdType, IdType2, DType>(feat, gfeat,
                                                      adj, inner_node,
                                                      ldt_key, gdt_key,
                                                      gdt_value,
                                                      node_map, lr,
                                                      lrtensor, num_nodes,
                                                      nc, c,
                                                      feat_size, labels,
                                                      trainm, testm,
                                                      valm, glabels,
                                                      gtrainm, gtestm,
                                                      gvalm, Nn);
      });
    });
  });
});


}  // namespace aten
}  // namespace dgl