Skip to content

Commit

Permalink
set default backend. (dmlc#1104)
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da authored Dec 13, 2019
1 parent adea4ba commit 1552090
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion apps/kg/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import time
import pickle

backend = os.environ.get('DGLBACKEND')
backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet':
from train_mxnet import load_model_from_checkpoint
from train_mxnet import test
Expand Down
2 changes: 1 addition & 1 deletion apps/kg/models/general_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import dgl.backend as F

backend = os.environ.get('DGLBACKEND')
backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet':
from .mxnet.tensor_models import logsigmoid
from .mxnet.tensor_models import get_device
Expand Down
2 changes: 1 addition & 1 deletion apps/kg/tests/test_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import dgl.backend as F
import dgl

backend = os.environ.get('DGLBACKEND')
backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet':
import mxnet as mx
mx.random.seed(42)
Expand Down
2 changes: 1 addition & 1 deletion apps/kg/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
import time

backend = os.environ.get('DGLBACKEND')
backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet':
import multiprocessing as mp
from train_mxnet import load_model
Expand Down

0 comments on commit 1552090

Please sign in to comment.