diff --git a/apps/kg/eval.py b/apps/kg/eval.py index 73e89e9153dc..646c1afaa3cd 100644 --- a/apps/kg/eval.py +++ b/apps/kg/eval.py @@ -8,7 +8,7 @@ import time import pickle -backend = os.environ.get('DGLBACKEND') +backend = os.environ.get('DGLBACKEND', 'pytorch') if backend.lower() == 'mxnet': from train_mxnet import load_model_from_checkpoint from train_mxnet import test diff --git a/apps/kg/models/general_models.py b/apps/kg/models/general_models.py index a66aad82151d..b9c95d36a3d4 100644 --- a/apps/kg/models/general_models.py +++ b/apps/kg/models/general_models.py @@ -2,7 +2,7 @@ import numpy as np import dgl.backend as F -backend = os.environ.get('DGLBACKEND') +backend = os.environ.get('DGLBACKEND', 'pytorch') if backend.lower() == 'mxnet': from .mxnet.tensor_models import logsigmoid from .mxnet.tensor_models import get_device diff --git a/apps/kg/tests/test_score.py b/apps/kg/tests/test_score.py index cfb4eb80367a..1a1f3589707c 100644 --- a/apps/kg/tests/test_score.py +++ b/apps/kg/tests/test_score.py @@ -5,7 +5,7 @@ import dgl.backend as F import dgl -backend = os.environ.get('DGLBACKEND') +backend = os.environ.get('DGLBACKEND', 'pytorch') if backend.lower() == 'mxnet': import mxnet as mx mx.random.seed(42) diff --git a/apps/kg/train.py b/apps/kg/train.py index a4edb26831dc..cc3decea9904 100644 --- a/apps/kg/train.py +++ b/apps/kg/train.py @@ -6,7 +6,7 @@ import logging import time -backend = os.environ.get('DGLBACKEND') +backend = os.environ.get('DGLBACKEND', 'pytorch') if backend.lower() == 'mxnet': import multiprocessing as mp from train_mxnet import load_model