Skip to content

Commit

Permalink
Refactor CPP installation checks to allow all modules to be imported …
Browse files Browse the repository at this point in the history
…for unit tests (facebookresearch#149)

Summary:
Pull Request resolved: facebookresearch#149

We have some checks so that Filament can run when the C++ code isn't installed, but unittest still tries to important `train_gpu.py` directly and fails. This moves the check *inside* `train_gpu`.

Reviewed By: lw

Differential Revision: D22073228

fbshipit-source-id: b75980c55c95eea31f82c786d9f6f35d48a146dd
  • Loading branch information
adamlerer authored and facebook-github-bot committed Jun 20, 2020
1 parent 6530836 commit 197e9c4
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 19 deletions.
9 changes: 5 additions & 4 deletions test/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from torchbiggraph.eval import do_eval
from torchbiggraph.partitionserver import run_partition_server
from torchbiggraph.stats import SerializedStats
from torchbiggraph.train import GPU_INSTALLED, train
from torchbiggraph.train import train
from torchbiggraph.train_gpu import CPP_INSTALLED
from torchbiggraph.util import (
SubprocessInitializer,
call_one_after_the_other,
Expand Down Expand Up @@ -491,15 +492,15 @@ def test_partitioned(self):
self.assertCheckpointWritten(train_config, version=1)
do_eval(eval_config, subprocess_init=self.subprocess_init)

@unittest.skipIf(not torch.cuda.is_available() or not GPU_INSTALLED, "No GPU")
@unittest.skipIf(not torch.cuda.is_available() or not CPP_INSTALLED, "No GPU")
def test_gpu(self):
self._test_gpu()

@unittest.skipIf(not torch.cuda.is_available() or not GPU_INSTALLED, "No GPU")
@unittest.skipIf(not torch.cuda.is_available() or not CPP_INSTALLED, "No GPU")
def test_gpu_half(self):
self._test_gpu(do_half_precision=True)

@unittest.skipIf(not torch.cuda.is_available() or not GPU_INSTALLED, "No GPU")
@unittest.skipIf(not torch.cuda.is_available() or not CPP_INSTALLED, "No GPU")
def test_gpu_1partition(self):
self._test_gpu(num_partitions=1)

Expand Down
15 changes: 1 addition & 14 deletions torchbiggraph/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from torchbiggraph.config import ConfigFileLoader, ConfigSchema, add_to_sys_path
from torchbiggraph.model import MultiRelationEmbedder
from torchbiggraph.train_cpu import TrainingCoordinator
from torchbiggraph.train_gpu import GPUTrainingCoordinator
from torchbiggraph.types import SINGLE_TRAINER, Rank
from torchbiggraph.util import (
SubprocessInitializer,
Expand All @@ -22,14 +23,6 @@
)


try:
from torchbiggraph.train_gpu import GPUTrainingCoordinator

GPU_INSTALLED = True
except ImportError:
GPU_INSTALLED = False


logger = logging.getLogger("torchbiggraph")
dist_logger = logging.LoggerAdapter(logger, {"distributed": True})

Expand All @@ -42,12 +35,6 @@ def train(
rank: Rank = SINGLE_TRAINER,
subprocess_init: Optional[Callable[[], None]] = None,
) -> None:
if config.num_gpus > 0 and not GPU_INSTALLED:
raise RuntimeError(
"GPU support requires C++ installation: "
"install with C++ support by running "
"`PBG_INSTALL_CPP=1 pip install .`"
)
CoordinatorT = (
GPUTrainingCoordinator if config.num_gpus > 0 else TrainingCoordinator
)
Expand Down
15 changes: 14 additions & 1 deletion torchbiggraph/train_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import torch
import torch.multiprocessing as mp
from torchbiggraph import _C
from torchbiggraph.batching import AbstractBatchProcessor, process_in_batches
from torchbiggraph.config import ConfigFileLoader, ConfigSchema, add_to_sys_path
from torchbiggraph.edgelist import EdgeList
Expand Down Expand Up @@ -61,6 +60,14 @@
)


try:
from torchbiggraph import _C

CPP_INSTALLED = True
except ImportError:
CPP_INSTALLED = False


logger = logging.getLogger("torchbiggraph")
dist_logger = logging.LoggerAdapter(logger, {"distributed": True})

Expand Down Expand Up @@ -423,6 +430,12 @@ def __init__(
)

assert config.num_gpus > 0
if not CPP_INSTALLED:
raise RuntimeError(
"GPU support requires C++ installation: "
"install with C++ support by running "
"`PBG_INSTALL_CPP=1 pip install .`"
)

if config.half_precision:
for entity in config.entities:
Expand Down

0 comments on commit 197e9c4

Please sign in to comment.