Skip to content

Commit

Permalink
[Backend] backend interface (dmlc#109)
Browse files Browse the repository at this point in the history
* backend interface

* small fix

* more comments to the data type dict

* WIP

* convert_to and narrow

* WIP

* pytorch and numpy backend; WIP on mxnet backend

* mxnet backend

* narrow

* Fix all usages

* fix for mx

* fix for mx

* fix mx

* fix mx

* fix mx

* fix mx

* fix mx

* fix mx

* fix mx

* revert jenkins

* add sparse_matrix api

* sparse matrix api

* some fixme

* Fix as requested
  • Loading branch information
jermainewang authored Nov 5, 2018
1 parent b420a5b commit 7241a9c
Show file tree
Hide file tree
Showing 23 changed files with 1,218 additions and 493 deletions.
65 changes: 51 additions & 14 deletions python/dgl/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,53 @@
from __future__ import absolute_import

import os

__backend__ = os.environ.get('DGLBACKEND', 'pytorch').lower()
if __backend__ == 'numpy':
from .numpy import *
create_immutable_graph_index=None
elif __backend__ == 'pytorch':
from .pytorch import *
create_immutable_graph_index=None
elif __backend__ == 'mxnet':
from .mxnet import *
from .mxnet_immutable_graph_index import create_immutable_graph_index
else:
raise Exception("Unsupported backend %s" % __backend__)
import sys, os
import importlib

from . import backend

_enabled_apis = set()

def _gen_missing_api(api, mod_name):
def _missing_api(*args, **kwargs):
raise ImportError('API "%s" is not supported by backend "%s".'
' You can switch to other backends by setting'
' the DGLBACKEND environment.' % (api, mod_name))
return _missing_api

def _load_backend():
mod_name = os.environ.get('DGLBACKEND', 'pytorch').lower()
mod = importlib.import_module('.%s' % mod_name, __name__)
thismod = sys.modules[__name__]
for api in backend.__dict__.keys():
if api == 'data_type_dict':
# load data type
if api not in mod.__dict__:
raise ImportError('API "data_type_dict" is required but missing for'
' backend "%s".' % (mod_name))
data_type_dict = mod.__dict__[api]()
for name, dtype in data_type_dict.items():
setattr(thismod, name, dtype)
else:
# load functions
if api in mod.__dict__:
_enabled_apis.add(api)
setattr(thismod, api, mod.__dict__[api])
else:
setattr(thismod, api, _gen_missing_api(api, mod_name))

_load_backend()

def is_enabled(api):
"""Return true if the api is enabled by the current backend.
Parameters
----------
api : str
The api name.
Returns
-------
bool
True if the API is enabled by the current backend.
"""
return api in _enabled_apis
Loading

0 comments on commit 7241a9c

Please sign in to comment.