Skip to content

Commit

Permalink
[Runtime] Scheduler and Executor (dmlc#140)
Browse files Browse the repository at this point in the history
* executor api

* draft executor interface

* WIP

* revert changes to avoid conflict with api change

* core scheduling logic

* WIP: build graph adj

* incidence matrix for in edges

* support incidence matrix for partial recv nodes

* improve

* build adjmat in scheduler

* graph store

* get degree bucketing schedule

* connect to c++ degree bucketing

* conceptual executor creation code

* executor comments

* fix

* more executor comments

* WIP: full send_and_recv schedule

* most schedulers

* simplify scheduler

* executors

* runtime

* builtin function base class

* adj indices and shape

* completely refactor scheduler

* rename and move bundled out to function.py

* use_edge_feature in msg func

* rewrite scheduler

* node edge executor

* connect with graph api

* handle zero degree

* misc

* fix test cases

* fix a good many bugs...

* remove old scheduler

* push and pull

* fix send recv

* c++ lint

* fix batched send recv

* hot fix for mxnet

* typo

* write back executor

* apply node edge

* clean up, doc string

* fix as requested

* refactor

* fix

* WIP

* WIP

* ir draft

* more on ir

* WIP: spmv schedule

* WIP

* recv schedule

* refactor

* WIP

* snr degree bucketing

* snr scheduler

* move prog to graph.py; rename

* unittest for send/recv

* remove some legacy codes

* WIP: update_all

* pass test_basics

* passed all current utests

* more utests; fix mx utest

* WIP: fixing zero deg initial value

* some tests

* fix 0deg problem

* fix mx

* fix mx

* some notes

* fix as requested
  • Loading branch information
lingfanyu authored and jermainewang committed Nov 22, 2018
1 parent 3e8b63e commit deb653f
Show file tree
Hide file tree
Showing 31 changed files with 2,395 additions and 1,010 deletions.
7 changes: 5 additions & 2 deletions include/dgl/scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ namespace sched {

/*!
* \brief Generate degree bucketing schedule
* \param vids The destination vertex for messages
* \param msg_ids The edge id for each message
* \param vids The destination vertex for each message
* \param recv_ids The recv nodes (for checking zero degree nodes)
* \note If there are multiple messages going into the same destination vertex, then
* there will be multiple copies of the destination vertex in vids
* \return a vector of 5 IdArrays for degree bucketing. The 5 arrays are:
Expand All @@ -27,7 +29,8 @@ namespace sched {
* mids: message ids
* mid_section: number of messages in each bucket (used to split mids)
*/
std::vector<IdArray> DegreeBucketing(const IdArray& vids);
std::vector<IdArray> DegreeBucketing(const IdArray& msg_ids, const IdArray& vids,
const IdArray& recv_ids);

} // namespace sched

Expand Down
75 changes: 62 additions & 13 deletions python/dgl/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
from __future__ import absolute_import

from collections import MutableMapping, namedtuple

import sys
import numpy as np

from . import backend as F
from .base import DGLError, dgl_warning
from .init import zero_initializer
from . import utils

import sys


class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
"""The column scheme.
Expand Down Expand Up @@ -160,11 +160,6 @@ def create(data):
else:
return Column(data)


def zero_initializer(shape, dtype, ctx):
return F.zeros(shape, dtype, ctx)


class Frame(MutableMapping):
"""The columnar storage for node/edge features.
Expand Down Expand Up @@ -320,7 +315,8 @@ def add_column(self, name, scheme, ctx):
if self.get_initializer(name) is None:
self._warn_and_set_initializer()
init_data = self.get_initializer(name)(
(self.num_rows,) + scheme.shape, scheme.dtype, ctx)
(self.num_rows,) + scheme.shape, scheme.dtype,
ctx, slice(0, self.num_rows))
self._columns[name] = Column(init_data, scheme)

def add_rows(self, num_rows):
Expand All @@ -334,19 +330,18 @@ def add_rows(self, num_rows):
num_rows : int
The number of new rows
"""
self._num_rows += num_rows

feat_placeholders = {}
for key, col in self._columns.items():
scheme = col.scheme
ctx = F.context(col.data)
if self.get_initializer(key) is None:
self._warn_and_set_initializer()
new_data = self.get_initializer(key)(
(num_rows,) + scheme.shape, scheme.dtype, ctx)
(num_rows,) + scheme.shape, scheme.dtype,
ctx, slice(self._num_rows, self._num_rows + num_rows))
feat_placeholders[key] = new_data

self._append(Frame(feat_placeholders))
self._num_rows += num_rows

def update_column(self, name, data):
"""Add or replace the column with the given name and data.
Expand Down Expand Up @@ -476,6 +471,21 @@ def set_initializer(self, initializer, column=None):
"""
self._frame.set_initializer(initializer, column=column)

def get_initializer(self, column=None):
"""Get the initializer for empty values for the given column.
Parameters
----------
column : str
The column
Returns
-------
callable
The initializer
"""
return self._frame.get_initializer(column)

def index(self):
"""Return the index object.
Expand Down Expand Up @@ -553,6 +563,12 @@ def __getitem__(self, key):
"""
if isinstance(key, str):
return self.select_column(key)
elif isinstance(key, slice) and key == slice(0, self.num_rows):
# shortcut for selecting all the rows
return self
elif isinstance(key, utils.Index) and key.is_slice(0, self.num_rows):
# shortcut for selecting all the rows
return self
else:
return self.select_rows(key)

Expand Down Expand Up @@ -616,6 +632,12 @@ def __setitem__(self, key, val):
"""
if isinstance(key, str):
self.update_column(key, val, inplace=False)
elif isinstance(key, slice) and key == slice(0, self.num_rows):
# shortcut for updating all the rows
return self.update(val)
elif isinstance(key, utils.Index) and key.is_slice(0, self.num_rows):
# shortcut for selecting all the rows
return self.update(val)
else:
self.update_rows(key, val, inplace=False)

Expand Down Expand Up @@ -807,6 +829,33 @@ def _clear_cache(self):
self._index = None
self._index_or_slice = None

def frame_like(other, num_rows):
"""Create a new frame that has the same scheme as the given one.
Parameters
----------
other : Frame
The given frame.
num_rows : int
The number of rows of the new one.
Returns
-------
Frame
The new frame.
"""
# TODO(minjie): scheme is not inherited at the moment. Fix this
# when moving per-col initializer to column scheme.
newf = Frame(num_rows=num_rows)
# set global initializr
if other.get_initializer() is None:
other._warn_and_set_initializer()
newf._default_initializer = other._default_initializer
# set per-col initializer
for key in other.keys():
newf.set_initializer(other.get_initializer(key), key)
return newf

def merge_frames(frames, indices, max_index, reduce_func):
"""Merge a list of frames.
Expand Down
1 change: 1 addition & 0 deletions python/dgl/function/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@

from .message import *
from .reducer import *
from .base import *
61 changes: 24 additions & 37 deletions python/dgl/function/base.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,31 @@
"""Built-in functions."""
"""Built-in function base class"""
from __future__ import absolute_import

from functools import update_wrapper
class BuiltinFunction(object):
"""Base builtin function class."""

__all__ = ['create_bundled_function_class']
def __call__(self):
"""Regular computation of this builtin function
def create_bundled_function_class(name, cls):
class Bundled(cls):
def __init__(self, fn_list):
if not isinstance(fn_list, (list, tuple)):
fn_list = [fn_list]
self.fn_list = fn_list
This will be used when optimization is not available.
"""
raise NotImplementedError

def is_spmv_supported(self, *args, **kwargs):
return all(isinstance(fn, cls) and
fn.is_spmv_supported(*args, **kwargs)
for fn in self.fn_list)
@property
def name(self):
"""Return the name of this builtin function."""
raise NotImplementedError

def __call__(self, *args, **kwargs):
ret = {}
for fn in self.fn_list:
result = fn(*args, **kwargs)
ret.update(result)
return ret
class BundledFunction(object):
def __init__(self, fn_list):
self.fn_list = fn_list

def name(self):
return "bundled"
def __call__(self, *args, **kwargs):
ret = {}
for fn in self.fn_list:
ret.update(fn(*args, **kwargs))
return ret

# Fake the names for introspection
Bundled.__module__ = cls.__module__
Bundled.__name__ = name
Bundled.__qualname__ = name

for method_name in ('__init__', '__call__', 'is_spmv_supported', 'name'):
method = getattr(Bundled, method_name)
method.__qualname__ = '{}.{}'.format(Bundled.__qualname__, method_name)

for method_name in ('__call__', 'is_spmv_supported', 'name'):
method = getattr(Bundled, method_name)
method = update_wrapper(method,
cls.__dict__[method.__name__],
('__module__', '__doc__', '__annotations__'))

return Bundled
@property
def name(self):
return "bundled"
36 changes: 27 additions & 9 deletions python/dgl/function/message.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Built-in message function."""
from __future__ import absolute_import

from .base import BuiltinFunction
import operator
import dgl.backend as F
from .base import create_bundled_function_class

__all__ = ["src_mul_edge", "copy_src", "copy_edge"]


class MessageFunction(object):
class MessageFunction(BuiltinFunction):
"""Base builtin message function class."""

def __call__(self, edges):
Expand All @@ -18,6 +18,7 @@ def __call__(self, edges):
"""
raise NotImplementedError

@property
def name(self):
"""Return the name of this builtin function."""
raise NotImplementedError
Expand All @@ -26,9 +27,9 @@ def is_spmv_supported(self, g):
"""Return whether the SPMV optimization is supported."""
raise NotImplementedError


BundledMessageFunction = create_bundled_function_class(
'BundledMessageFunction', MessageFunction)
@property
def use_edge_feature(self):
raise NotImplementedError


def _is_spmv_supported_node_feat(g, field):
Expand Down Expand Up @@ -64,15 +65,22 @@ def is_spmv_supported(self, g):
def __call__(self, edges):
src_data = edges.src[self.src_field]
edata = edges.data[self.edge_field]
src_dim = F.ndim(src_data)
eshape = F.shape(edata)[0]
ret = self.mul_op(edges.src[self.src_field],
F.reshape(edges.data[self.edge_field], (eshape,) + (1,) * (src_dim - 1)))
if F.ndim(edata) == 1:
# edge feature is a scalar, unsqueeze dims of len 1
src_dim = F.ndim(src_data)
new_eshape = (F.shape(edata)[0],) + (1,) * (src_dim - 1)
edata = F.reshape(edata, new_eshape)
ret = self.mul_op(src_data, edata)
return {self.out_field : ret}

@property
def name(self):
return "src_mul_edge"

@property
def use_edge_feature(self):
return True

class CopySrcMessageFunction(MessageFunction):
def __init__(self, src_field, out_field):
self.src_field = src_field
Expand All @@ -84,9 +92,14 @@ def is_spmv_supported(self, g):
def __call__(self, edges):
return {self.out_field : edges.src[self.src_field]}

@property
def name(self):
return "copy_src"

@property
def use_edge_feature(self):
return False

class CopyEdgeMessageFunction(MessageFunction):
def __init__(self, edge_field=None, out_field=None):
self.edge_field = edge_field
Expand All @@ -100,9 +113,14 @@ def is_spmv_supported(self, g):
def __call__(self, edges):
return {self.out_field : edges.data[self.edge_field]}

@property
def name(self):
return "copy_edge"

@property
def use_edge_feature(self):
return True


def src_mul_edge(src, edge, out):
"""Builtin message function that computes message by multiplying source node features
Expand Down
10 changes: 4 additions & 6 deletions python/dgl/function/reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from __future__ import absolute_import

from .. import backend as F
from .base import create_bundled_function_class
from .base import BuiltinFunction

__all__ = ["sum", "max"]

class ReduceFunction(object):
class ReduceFunction(BuiltinFunction):
"""Base builtin reduce function class."""

def __call__(self, nodes):
Expand All @@ -16,6 +16,7 @@ def __call__(self, nodes):
"""
raise NotImplementedError

@property
def name(self):
"""Return the name of this builtin function."""
raise NotImplementedError
Expand All @@ -25,10 +26,6 @@ def is_spmv_supported(self):
raise NotImplementedError


BundledReduceFunction = create_bundled_function_class(
'BundledReduceFunction', ReduceFunction)


class SimpleReduceFunction(ReduceFunction):
"""Builtin reduce function that aggregates a single field into another
single field."""
Expand All @@ -45,6 +42,7 @@ def is_spmv_supported(self):
def __call__(self, nodes):
return {self.out_field : self.op(nodes.mailbox[self.msg_field], 1)}

@property
def name(self):
return self._name

Expand Down
Loading

0 comments on commit deb653f

Please sign in to comment.