Fix GPU functional tests on FB
Summary:
- GPU tests don't work on devgpu due to the infamous cudart issue; this workaround gets them working. Shrug.
- `buck test` is annoying because it eats stdout/stderr from the child processes. I want to run the tests directly from Python instead, but that requires an ifbpy with a `torchbiggraph` dependency, so I added this target.

Reviewed By: lw

Differential Revision: D29560732

fbshipit-source-id: 4198b2636896327f6d6b67477db731f760bde135
adamlerer authored and facebook-github-bot committed Jul 30, 2021
1 parent 24bd54e commit c51af52
Showing 2 changed files with 17 additions and 7 deletions.
11 changes: 11 additions & 0 deletions ifbpy.py
@@ -0,0 +1,11 @@
+#!/usr/bin/env python3
+import warnings
+
+from libfb.py.ipython_par import launch_ipython
+
+
+warnings.simplefilter("ignore", category=DeprecationWarning)
+warnings.simplefilter("ignore", category=ResourceWarning)
+
+
+launch_ipython()
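This target makes it possible to drive the tests from an interactive Python shell, where the child processes' stdout/stderr stay visible. A hypothetical session sketch (the test module path below is illustrative, not part of this commit):

# Inside the ifbpy REPL: run a test module directly via unittest so that
# subprocess output reaches the terminal instead of being swallowed.
import unittest

# "torchbiggraph.tests.test_gpu" is an assumed, illustrative module name.
suite = unittest.defaultTestLoader.loadTestsFromName("torchbiggraph.tests.test_gpu")
unittest.TextTestRunner(verbosity=2).run(suite)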
13 changes: 6 additions & 7 deletions torchbiggraph/train_gpu.py
@@ -130,6 +130,7 @@ def __init__(
     ) -> None:
         super().__init__(daemon=True, name=f"GPU #{gpu_idx}")
         self.gpu_idx = gpu_idx
+
         self.master_endpoint, self.worker_endpoint = mp.get_context("spawn").Pipe()
         self.subprocess_init = subprocess_init
         self.sub_holder: Dict[
@@ -143,20 +144,17 @@ def my_device(self) -> torch.device:
         return torch.device("cuda", index=self.gpu_idx)
 
     def run(self) -> None:
 
         torch.set_num_threads(1)
         torch.cuda.set_device(self.my_device)
         if self.subprocess_init is not None:
             self.subprocess_init()
         self.master_endpoint.close()
 
         for s in self.embedding_storage_freelist:
             assert s.is_shared()
-            cptr = ctypes.c_void_p(s.data_ptr())
-            csize = ctypes.c_size_t(s.size() * s.element_size())
-            cflags = ctypes.c_uint(0)
-            # FIXME: broken by D20249187
-            # cudart = torch.cuda.cudart()
-            cudart = ctypes.cdll.LoadLibrary(ctypes.util.find_library("cudart"))
-            res = cudart.cudaHostRegister(cptr, csize, cflags)
+            cudart = torch.cuda.cudart()
+            res = cudart.cudaHostRegister(
+                s.data_ptr(), s.size() * s.element_size(), 0
+            )
             torch.cuda.check_error(res)
             assert s.is_pinned()
             logger.info(f"GPU subprocess {self.gpu_idx} up and running")
@@ -615,6 +613,7 @@ def schedule(gpu_idx: GPURank) -> None:
 
         for gpu_idx in range(self.gpu_pool.num_gpus):
             schedule(gpu_idx)
+
         while busy_gpus:
             gpu_idx, result = self.gpu_pool.wait_for_next()
             assert gpu_idx == result.gpu_idx
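For reference, a minimal standalone sketch of the pinning pattern this commit restores, assuming a CUDA-enabled PyTorch build where torch.cuda.cudart() works again (the premise of the fix) and the storage API of the PyTorch version this commit targets; the tensor size is arbitrary:

import torch

# CPU tensor standing in for one of the shared embedding buffers.
t = torch.zeros(1024)
s = t.storage()
s.share_memory_()  # shared across processes, like the freelist storages
assert s.is_shared()

# Register the storage as pinned (page-locked) host memory through the
# cudart bindings, instead of hand-loading libcudart through ctypes.
cudart = torch.cuda.cudart()
res = cudart.cudaHostRegister(s.data_ptr(), s.size() * s.element_size(), 0)
torch.cuda.check_error(res)
assert s.is_pinned()  # pinned memory enables fast async host-to-device copies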
