forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Backend] backend interface (dmlc#109)
* backend interface * small fix * more comments to the data type dict * WIP * convert_to and narrow * WIP * pytorch and numpy backend; WIP on mxnet backend * mxnet backend * narrow * Fix all usages * fix for mx * fix for mx * fix mx * fix mx * fix mx * fix mx * fix mx * fix mx * fix mx * revert jenkins * add sparse_matrix api * sparse matrix api * some fixme * Fix as requested
- Loading branch information
1 parent
b420a5b
commit 7241a9c
Showing
23 changed files
with
1,218 additions
and
493 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,53 @@ | ||
from __future__ import absolute_import | ||
|
||
import os | ||
|
||
__backend__ = os.environ.get('DGLBACKEND', 'pytorch').lower() | ||
if __backend__ == 'numpy': | ||
from .numpy import * | ||
create_immutable_graph_index=None | ||
elif __backend__ == 'pytorch': | ||
from .pytorch import * | ||
create_immutable_graph_index=None | ||
elif __backend__ == 'mxnet': | ||
from .mxnet import * | ||
from .mxnet_immutable_graph_index import create_immutable_graph_index | ||
else: | ||
raise Exception("Unsupported backend %s" % __backend__) | ||
import sys, os | ||
import importlib | ||
|
||
from . import backend | ||
|
||
_enabled_apis = set() | ||
|
||
def _gen_missing_api(api, mod_name): | ||
def _missing_api(*args, **kwargs): | ||
raise ImportError('API "%s" is not supported by backend "%s".' | ||
' You can switch to other backends by setting' | ||
' the DGLBACKEND environment.' % (api, mod_name)) | ||
return _missing_api | ||
|
||
def _load_backend(): | ||
mod_name = os.environ.get('DGLBACKEND', 'pytorch').lower() | ||
mod = importlib.import_module('.%s' % mod_name, __name__) | ||
thismod = sys.modules[__name__] | ||
for api in backend.__dict__.keys(): | ||
if api == 'data_type_dict': | ||
# load data type | ||
if api not in mod.__dict__: | ||
raise ImportError('API "data_type_dict" is required but missing for' | ||
' backend "%s".' % (mod_name)) | ||
data_type_dict = mod.__dict__[api]() | ||
for name, dtype in data_type_dict.items(): | ||
setattr(thismod, name, dtype) | ||
else: | ||
# load functions | ||
if api in mod.__dict__: | ||
_enabled_apis.add(api) | ||
setattr(thismod, api, mod.__dict__[api]) | ||
else: | ||
setattr(thismod, api, _gen_missing_api(api, mod_name)) | ||
|
||
_load_backend() | ||
|
||
def is_enabled(api): | ||
"""Return true if the api is enabled by the current backend. | ||
Parameters | ||
---------- | ||
api : str | ||
The api name. | ||
Returns | ||
------- | ||
bool | ||
True if the API is enabled by the current backend. | ||
""" | ||
return api in _enabled_apis |
Oops, something went wrong.