
Commit

Prevent mixing model parallel and multiprocessing. (facebookresearch#2964)

* Prevent mixing model parallel and multiprocessing.

* Lint.
stephenroller authored Aug 25, 2020
1 parent 5987ff1 commit e478d13
Showing 7 changed files with 46 additions and 3 deletions.
4 changes: 3 additions & 1 deletion parlai/core/torch_classifier_agent.py
@@ -348,7 +348,9 @@ def __init__(self, opt: Opt, shared=None):
             self.load(init_model)
         if self.use_cuda:
             if self.model_parallel:
-                self.model = PipelineHelper().make_parallel(self.model)
+                ph = PipelineHelper()
+                ph.check_compatibility(self.opt)
+                self.model = ph.make_parallel(self.model)
             else:
                 self.model.cuda()
             if self.data_parallel:
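
The same guard (construct a PipelineHelper, call check_compatibility on the opt, then wrap the model) is repeated verbatim in the generator and ranker agents below. A minimal sketch of the pattern in isolation, with hypothetical stubs standing in for the real ParlAI classes:

    # Hypothetical stubs for illustration only; the real PipelineHelper lives
    # in parlai/utils/torch.py (see that hunk below).
    class PipelineHelper:
        def check_compatibility(self, opt):
            # The real method raises RuntimeError on bad opt combinations.
            pass

        def make_parallel(self, model):
            # The real method splits layers across GPUs; this stub is a no-op.
            return model

    opt = {'model_parallel': True}
    model = object()

    ph = PipelineHelper()
    ph.check_compatibility(opt)      # validate first, before touching any GPU
    model = ph.make_parallel(model)  # only then distribute the model

Splitting the old one-liner into three statements lets the validation run before any layers are moved to devices, so a bad configuration raises a clear error up front instead of hanging later.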
4 changes: 3 additions & 1 deletion parlai/core/torch_generator_agent.py
@@ -503,7 +503,9 @@ def __init__(self, opt: Opt, shared=None):
             )
         if self.use_cuda:
             if self.model_parallel:
-                self.model = PipelineHelper().make_parallel(self.model)
+                ph = PipelineHelper()
+                ph.check_compatibility(self.opt)
+                self.model = ph.make_parallel(self.model)
             else:
                 self.model.cuda()
             self.criterion.cuda()
4 changes: 3 additions & 1 deletion parlai/core/torch_ranker_agent.py
@@ -213,7 +213,9 @@ def __init__(self, opt: Opt, shared=None):
 
         if self.use_cuda:
             if self.model_parallel:
-                self.model = PipelineHelper().make_parallel(self.model)
+                ph = PipelineHelper()
+                ph.check_compatibility(self.opt)
+                self.model = ph.make_parallel(self.model)
             else:
                 self.model.cuda()
             if self.data_parallel:
1 change: 1 addition & 0 deletions parlai/scripts/multiprocessing_eval.py
@@ -43,6 +43,7 @@ def multiprocess_eval(
     with distributed_utils.distributed_context(
         rank, opt, port, rank_offset, gpu, hostname
     ) as opt:
+        opt['multiprocessing'] = True
         return eval_model.eval_model(opt)


1 change: 1 addition & 0 deletions parlai/scripts/multiprocessing_train.py
@@ -39,6 +39,7 @@ def multiprocess_train(
         rank, opt, port, rank_offset, gpu, hostname
     ) as opt:
         # Run the actual training
+        opt['multiprocessing'] = True
         return single_train.TrainLoop(opt).train()


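
Both multiprocessing entry points (eval above, train here) set the flag inside distributed_context, so it only appears in workers spawned by these scripts; a plain train_model run never sets it, and the new guard in parlai/utils/torch.py stays quiet there. A minimal sketch of that distinction, with illustrative helper names that are not part of ParlAI:

    # Illustrative helpers only; not ParlAI APIs.
    def opt_for_multiprocessing(opt):
        opt = dict(opt)
        opt['multiprocessing'] = True  # mirrors the line added above
        return opt

    def opt_for_single_process(opt):
        return dict(opt)  # the flag is never set on this path

    assert opt_for_multiprocessing({}).get('multiprocessing') is True
    assert opt_for_single_process({}).get('multiprocessing') is None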
18 changes: 18 additions & 0 deletions parlai/utils/torch.py
@@ -333,6 +333,24 @@ def __init__(self):
             self.devices.append(d)
             self.__device_allocations[d] = 0
 
+    def check_compatibility(self, opt):
+        """
+        Check compatibility for opts.
+
+        Really just used to raise an error message if the user mixes
+        multiprocessing and model parallelism.
+        """
+        if opt.get('multiprocessing') and not os.environ.get('PARLAI_FORCE_MP'):
+            raise RuntimeError(
+                "It looks like you are trying to mix multiprocessing data "
+                "parallelism (multiprocessing_train or multiprocessing_eval) "
+                "with --model-parallel true. This is almost certainly a user "
+                "error, and is going to result in hanging as the two methods "
+                "fight for resources. Use simple `train_model` instead of "
+                "`mp_train`, or add `--model-parallel false`. For more info, "
+                "see https://github.com/facebookresearch/ParlAI/issues/2962."
+            )
+
     def make_parallel(self, model: torch.nn.Module) -> torch.nn.Module:
         """
         Allocate specific layers in a model to be ModelParallel.
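
PARLAI_FORCE_MP acts as an escape hatch for anyone who deliberately wants to combine the two modes. A standalone, runnable sketch of the guard's behavior, reimplementing the check above rather than importing ParlAI:

    import os

    def check_compatibility(opt):
        # Same logic as the method added above: refuse multiprocessing plus
        # model parallel unless the user opts in via PARLAI_FORCE_MP.
        if opt.get('multiprocessing') and not os.environ.get('PARLAI_FORCE_MP'):
            raise RuntimeError('mixing multiprocessing and model parallel')

    opt = {'multiprocessing': True, 'model_parallel': True}
    try:
        check_compatibility(opt)
    except RuntimeError as err:
        print('guard fired:', err)

    os.environ['PARLAI_FORCE_MP'] = '1'
    check_compatibility(opt)  # passes silently once the override is set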
17 changes: 17 additions & 0 deletions tests/test_distributed.py
@@ -169,6 +169,23 @@ def test_chunked_teacher(self):
         assert valid['exs'].value() == 100
         assert test['exs'].value() == 100
 
+    def test_no_model_parallel(self):
+        """
+        Checks that we throw an error when combining mp_train with
+        --model-parallel true.
+        """
+        config = copy.deepcopy(self._base_config)
+        config['model_parallel'] = True
+        for m in [
+            'transformer/generator',
+            'transformer/ranker',
+            'transformer/classifier',
+        ]:
+            config['model'] = m
+            with self.assertRaises(RuntimeError):
+                _ = self._distributed_train_model(config)
+
 
 if __name__ == '__main__':
     unittest.main()
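
Assuming a standard pytest setup in the repository, the new test can be run on its own with:

    pytest tests/test_distributed.py -k test_no_model_parallel

For each of the three model types, the assertRaises block expects _distributed_train_model to surface the RuntimeError raised by check_compatibility before training begins.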
