Only import torch.distributed when needed (open-mmlab#882)
* Fix an import error for `get_world_size` and `get_rank`

* Only import torch.distributed when needed

torch.distributed is only used in DistributedGroupSampler

* Use `get_dist_info` to obtain world size and rank

`get_dist_info` from `mmcv.runner.utils` handles the case where `torch.distributed.distributed_c10d` does not exist.
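For background, here is a minimal sketch of the fallback logic that `get_dist_info` provides. This is an approximation for illustration, not the exact mmcv source of the time; the key point is that it only touches `torch.distributed` through the module object, which stays safe even on PyTorch builds where `distributed_c10d` is missing.

```python
# Sketch of get_dist_info's behaviour (an approximation of
# mmcv.runner.utils.get_dist_info, not the exact implementation).
import torch.distributed as dist

def get_dist_info():
    # Accessing the module and calling is_available() works even on
    # builds without the c10d backend; `from torch.distributed import
    # get_rank, get_world_size` at module import time does not.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        # Non-distributed fallback: a single process acting as rank 0.
        rank = 0
        world_size = 1
    return rank, world_size
```
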
zhangtemplar authored and hellock committed Jun 27, 2019
1 parent f080ccb commit 8bf38df
Showing 1 changed file with 4 additions and 3 deletions.
mmdet/datasets/loader/sampler.py

@@ -4,7 +4,7 @@
 import torch
 import numpy as np
 
-from torch.distributed import get_world_size, get_rank
+from mmcv.runner.utils import get_dist_info
 from torch.utils.data import Sampler
 from torch.utils.data import DistributedSampler as _DistributedSampler

@@ -95,10 +95,11 @@ def __init__(self,
                  samples_per_gpu=1,
                  num_replicas=None,
                  rank=None):
+        _rank, _num_replicas = get_dist_info()
         if num_replicas is None:
-            num_replicas = get_world_size()
+            num_replicas = _num_replicas
         if rank is None:
-            rank = get_rank()
+            rank = _rank
         self.dataset = dataset
         self.samples_per_gpu = samples_per_gpu
         self.num_replicas = num_replicas
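With this patch applied, constructing the sampler without explicit `num_replicas` or `rank` no longer requires an initialized (or even available) distributed backend. A hypothetical usage sketch, where `my_dataset` stands in for any mmdet dataset exposing the group `flag` attribute the sampler relies on:

```python
from mmdet.datasets.loader.sampler import DistributedGroupSampler

# Single-process run: get_dist_info() falls back to rank 0 and a world
# size of 1, so both arguments can safely be omitted.
sampler = DistributedGroupSampler(my_dataset, samples_per_gpu=2)

# In a real distributed job the defaults pick up the actual rank and
# world size; explicitly passed values still take precedence.
```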
