Skip to content

Commit

Permalink
Offer a way to do user-provided initialization in every subprocess
Browse files Browse the repository at this point in the history
Summary: Essentially just a lot of plumbing, which for now is left unused.

Reviewed By: adamlerer

Differential Revision: D16581941

fbshipit-source-id: 68324c85528aabca2a890eb6536c650cc0fb168a
  • Loading branch information
lw authored and facebook-github-bot committed Aug 2, 2019
1 parent e8da946 commit 4eb3f25
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 17 deletions.
8 changes: 6 additions & 2 deletions torchbiggraph/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from abc import ABC, abstractmethod
from datetime import timedelta
from typing import List, NamedTuple, Optional
from typing import Callable, List, NamedTuple, Optional

import torch.distributed as td
import torch.multiprocessing as mp
Expand Down Expand Up @@ -103,7 +103,10 @@ def _server_init(
world_size: int,
server_rank: Rank,
groups: List[List[Rank]],
subprocess_init: Optional[Callable[[], None]] = None,
) -> None:
if subprocess_init is not None:
subprocess_init()
init_process_group(
init_method=init_method,
world_size=world_size,
Expand All @@ -119,11 +122,12 @@ def start_server(
world_size: int,
server_rank: Rank,
groups: List[List[Rank]],
subprocess_init: Optional[Callable[[], None]] = None,
) -> mp.Process:
p = mp.get_context("spawn").Process(
name="%s-%d" % (type(server).__name__, server_rank),
target=_server_init,
args=(server, init_method, world_size, server_rank, groups),
args=(server, init_method, world_size, server_rank, groups, subprocess_init),
)
p.daemon = True
p.start()
Expand Down
8 changes: 5 additions & 3 deletions torchbiggraph/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
from functools import partial
from itertools import chain
from typing import Generator, List, Optional, Tuple
from typing import Callable, Generator, List, Optional, Tuple

import torch

Expand Down Expand Up @@ -78,6 +78,7 @@ def do_eval_and_report_stats(
config: ConfigSchema,
model: Optional[MultiRelationEmbedder] = None,
evaluator: Optional[AbstractBatchProcessor] = None,
subprocess_init: Optional[Callable[[], None]] = None,
) -> Generator[Tuple[Optional[int], Optional[Bucket], Stats], None, None]:
"""Computes eval metrics (mr/mrr/r1/r10/r50) for a checkpoint with trained
embeddings.
Expand All @@ -101,7 +102,7 @@ def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)

num_workers = get_num_workers(config.workers)
pool = create_pool(num_workers)
pool = create_pool(num_workers, subprocess_init=subprocess_init)

if model is None:
model = make_model(config)
Expand Down Expand Up @@ -200,9 +201,10 @@ def do_eval(
config: ConfigSchema,
model: Optional[MultiRelationEmbedder] = None,
evaluator: Optional[AbstractBatchProcessor] = None,
subprocess_init: Optional[Callable[[], None]] = None,
) -> None:
# Create and run the generator until exhaustion.
for _ in do_eval_and_report_stats(config, model, evaluator):
for _ in do_eval_and_report_stats(config, model, evaluator, subprocess_init):
pass


Expand Down
5 changes: 3 additions & 2 deletions torchbiggraph/fileio.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import re
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

import h5py
import numpy as np
Expand Down Expand Up @@ -477,6 +477,7 @@ def __init__(
num_machines: int = 1,
background: bool = False,
partition_client: Optional[PartitionClient] = None,
subprocess_init: Optional[Callable[[], None]] = None,
) -> None:
"""
Args:
Expand Down Expand Up @@ -511,7 +512,7 @@ def __init__(

self.background: bool = background
if self.background:
self.pool: mp.Pool = create_pool(1)
self.pool: mp.Pool = create_pool(1, subprocess_init=subprocess_init)
# FIXME In py-3.7 switch to typing.OrderedDict[str, AsyncResult].
self.outstanding: OrderedDict = OrderedDict()
self.prefetched: Dict[str, Tuple[FloatTensorType, Optional[OptimizerStateDict]]] = {}
Expand Down
17 changes: 15 additions & 2 deletions torchbiggraph/parameter_sharing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import queue
import time
import traceback
from typing import Dict, List, Optional, Set
from typing import Callable, Dict, List, Optional, Set

import torch
import torch.distributed as td
Expand Down Expand Up @@ -323,10 +323,13 @@ def _client_thread_loop(
init_method: Optional[str],
world_size: int,
groups: List[List[Rank]],
subprocess_init: Optional[Callable[[], None]] = None,
max_bandwidth: float = 1e8,
min_sleep_time: float = 0.01,
) -> None:
try:
if subprocess_init is not None:
subprocess_init()
init_process_group(
rank=client_rank,
init_method=init_method,
Expand Down Expand Up @@ -402,13 +405,23 @@ def __init__(
init_method: Optional[str],
world_size: int,
groups: List[List[Rank]],
subprocess_init: Optional[Callable[[], None]] = None,
) -> None:
self.q = mp.get_context("spawn").Queue()
self.errq = mp.get_context("spawn").Queue()
self.p = mp.get_context("spawn").Process(
name="ParameterClient-%d" % client_rank,
target=_client_thread_loop,
args=(client_rank, all_server_ranks, self.q, self.errq, init_method, world_size, groups)
args=(
client_rank,
all_server_ranks,
self.q,
self.errq,
init_method,
world_size,
groups,
subprocess_init,
),
)
self.p.daemon = True
self.p.start()
Expand Down
13 changes: 12 additions & 1 deletion torchbiggraph/partitionserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,28 @@

import argparse
from itertools import chain
from typing import Callable, Optional

import torch.distributed as td

from torchbiggraph.config import ConfigSchema, parse_config
from torchbiggraph.distributed import ProcessRanks, init_process_group
from torchbiggraph.parameter_sharing import ParameterServer
from torchbiggraph.types import Rank

# This is a small binary that just runs a partition server.
# You need to run it when doing a distributed run with
# num_partition_servers > 1.


def run_partition_server(config, rank=0):
RANK_ZERO = Rank(0)


def run_partition_server(
config: ConfigSchema,
rank: Rank = RANK_ZERO,
subprocess_init: Optional[Callable[[], None]] = None,
) -> None:
if config.num_partition_servers <= 0:
raise RuntimeError("Config doesn't require explicit partition servers")
if not 0 <= rank < config.num_partition_servers:
Expand All @@ -30,6 +39,8 @@ def run_partition_server(config, rank=0):
"distributed training capabilities.")
ranks = ProcessRanks.from_num_invocations(
config.num_machines, config.num_partition_servers)
if subprocess_init is not None:
subprocess_init()
init_process_group(
rank=ranks.partition_servers[rank],
world_size=ranks.world_size,
Expand Down
13 changes: 10 additions & 3 deletions torchbiggraph/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from abc import ABC, abstractmethod
from functools import partial
from itertools import chain
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple

import torch
import torch.distributed as td
Expand Down Expand Up @@ -256,6 +256,7 @@ def train_and_report_stats(
trainer: Optional[AbstractBatchProcessor] = None,
evaluator: Optional[AbstractBatchProcessor] = None,
rank: Rank = RANK_ZERO,
subprocess_init: Optional[Callable[[], None]] = None,
) -> Generator[Tuple[int, Optional[Stats], Stats, Optional[Stats]], None, None]:
"""Each epoch/pass, for each partition pair, loads in embeddings and edgelist
from disk, runs HOGWILD training on them, and writes partitions back to disk.
Expand Down Expand Up @@ -310,6 +311,7 @@ def train_and_report_stats(
world_size=ranks.world_size,
init_method=config.distributed_init_method,
groups=[ranks.trainers],
subprocess_init=subprocess_init,
)

bucket_scheduler = DistributedBucketScheduler(
Expand All @@ -324,6 +326,7 @@ def train_and_report_stats(
init_method=config.distributed_init_method,
world_size=ranks.world_size,
groups=[ranks.trainers],
subprocess_init=subprocess_init,
)

parameter_sharer = ParameterSharer(
Expand All @@ -332,6 +335,7 @@ def train_and_report_stats(
init_method=config.distributed_init_method,
world_size=ranks.world_size,
groups=[ranks.trainers],
subprocess_init=subprocess_init,
)

if config.num_partition_servers == -1:
Expand All @@ -341,6 +345,7 @@ def train_and_report_stats(
world_size=ranks.world_size,
init_method=config.distributed_init_method,
groups=[ranks.trainers],
subprocess_init=subprocess_init,
)

if len(ranks.partition_servers) > 0:
Expand Down Expand Up @@ -369,7 +374,7 @@ def train_and_report_stats(
# fork early for HOGWILD threads
log("Creating workers...")
num_workers = get_num_workers(config.workers)
pool = create_pool(num_workers)
pool = create_pool(num_workers, subprocess_init=subprocess_init)

def make_optimizer(params: Iterable[torch.nn.Parameter], is_emb: bool) -> Optimizer:
params = list(params)
Expand All @@ -395,6 +400,7 @@ def make_optimizer(params: Iterable[torch.nn.Parameter], is_emb: bool) -> Optimi
rank=rank,
num_machines=config.num_machines,
partition_client=partition_client,
subprocess_init=subprocess_init,
)
checkpoint_manager.register_metadata_provider(ConfigMetadataProvider(config))
checkpoint_manager.write_config(config)
Expand Down Expand Up @@ -790,9 +796,10 @@ def train(
trainer: Optional[AbstractBatchProcessor] = None,
evaluator: Optional[AbstractBatchProcessor] = None,
rank: Rank = RANK_ZERO,
subprocess_init: Optional[Callable[[], None]] = None,
) -> None:
# Create and run the generator until exhaustion.
for _ in train_and_report_stats(config, model, trainer, evaluator, rank):
for _ in train_and_report_stats(config, model, trainer, evaluator, rank, subprocess_init):
pass


Expand Down
15 changes: 11 additions & 4 deletions torchbiggraph/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import os.path
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, Iterable, Optional, Set, Tuple
from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple

import torch
import torch.multiprocessing as mp
Expand Down Expand Up @@ -93,12 +93,17 @@ def share_memory(self) -> None:

# HOGWILD

def _pool_init():
def _pool_init(subprocess_init: Optional[Callable[[], None]] = None) -> None:
torch.set_num_threads(1)
torch.manual_seed(os.getpid())
if subprocess_init is not None:
subprocess_init()


def create_pool(num_workers: int) -> mp.Pool:
def create_pool(
num_workers: int,
subprocess_init: Optional[Callable[[], None]] = None,
) -> mp.Pool:
# PyTorch relies on OpenMP, which by default parallelizes operations by
# implicitly spawning as many threads as there are cores, and synchronizing
# them with each other. This interacts poorly with Hogwild!-style subprocess
Expand All @@ -112,7 +117,9 @@ def create_pool(num_workers: int) -> mp.Pool:
# https://github.com/pytorch/pytorch/issues/17199 for some more information
# and discussion.
torch.set_num_threads(1)
return mp.get_context("spawn").Pool(num_workers, initializer=_pool_init)
return mp.get_context("spawn").Pool(
num_workers, initializer=_pool_init, initargs=(subprocess_init,)
)


# config routines
Expand Down

0 comments on commit 4eb3f25

Please sign in to comment.