Commit: modify world_size, rank
tree-park committed Oct 31, 2022
1 parent 8331315 commit cb66aad
Showing 2 changed files with 27 additions and 30 deletions.
25 changes: 13 additions & 12 deletions oslo/transformers/trainer.py
@@ -65,6 +65,7 @@
     PipelineParallel,
     TensorParallel,
 )
+from oslo.torch.distributed.parallel_mode import ParallelMode
 from oslo.torch.nn.parallel.sequence_parallel import SequenceParallel
 from oslo.torch.nn.parallel.data_parallel.data_parallel import DataParallel
 from oslo.torch.nn.parallel.data_parallel._ddp.distributed_data_parallel import (
@@ -805,22 +806,22 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
         # seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed

         if (
-            self.args.world_size > 1
+            self.parallel_context.get_local_rank(ParallelMode.DATA) > 1
             and self.args.oslo_config.data_parallelism is not None
         ):
             if not self.args.dataloader_drop_last:
                 return DistributedSamplerWithLoop(
                     self.train_dataset,
                     batch_size=self.args.per_device_train_batch_size,
-                    num_replicas=self.args.world_size,
-                    rank=self.args.process_index,
+                    num_replicas=self.parallel_context.get_local_rank(ParallelMode.DATA),
+                    rank=self.parallel_context.get_local_rank(ParallelMode.DATA),
                     # seed=seed, TODO oslo seed
                 )
             else:
                 return DistributedSampler(
                     self.train_dataset,
-                    num_replicas=self.args.world_size,
-                    rank=self.args.process_index,
+                    num_replicas=self.parallel_context.get_local_rank(ParallelMode.DATA),
+                    rank=self.parallel_context.get_local_rank(ParallelMode.DATA),
                     # seed=seed, TODO oslo seed
                 )
         else:
@@ -843,21 +844,21 @@ def get_train_dataloader(self) -> DataLoader:
         # if isinstance(train_dataset, datasets.Dataset):
         #     train_dataset = self._remove_unused_columns(train_dataset, description="training")
         log_dist(f"Collate_fn: {self.data_collator.__class__}")
-        if self.args.dataloader_num_workers % self.args.world_size != 0:
+        if self.args.dataloader_num_workers % self.parallel_context.get_local_rank(ParallelMode.DATA) != 0:
             raise ValueError("dataloader_num_workers should be dividable by world_size")
-        num_workers = self.args.dataloader_num_workers / self.args.world_size
+        num_workers = self.args.dataloader_num_workers / self.parallel_context.get_local_rank(ParallelMode.DATA)

         if isinstance(train_dataset, torch.utils.data.IterableDataset):
-            if self.args.world_size > 1:
+            if self.parallel_context.get_local_rank(ParallelMode.DATA) > 1:
                 train_dataset = IterableDatasetShard(
                     train_dataset,
                     batch_size=self.args.train_batch_size,
                     drop_last=self.args.dataloader_drop_last,
-                    num_processes=self.args.world_size,
-                    process_index=self.args.process_index,
+                    num_processes=self.parallel_context.get_local_rank(ParallelMode.DATA),
+                    process_index=self.parallel_context.get_local_rank(ParallelMode.DATA),
                 )
             log_dist(
-                f"Dataset: {train_dataset.__class__} with\nbatch_size:{self.args.train_batch_size}\n world_size:{self.args.world_size}\n dataloader_drop_last: {self.args.dataloader_drop_last}"
+                f"Dataset: {train_dataset.__class__} with\nbatch_size:{self.args.train_batch_size}\n world_size:{self.parallel_context.get_local_rank(ParallelMode.DATA)}\n dataloader_drop_last: {self.args.dataloader_drop_last}"
             )
             return DataLoader(
                 train_dataset,
@@ -867,7 +868,7 @@ def get_train_dataloader(self) -> DataLoader:
         )
         train_sampler = self._get_train_sampler()
         log_dist(
-            f"Sampler: {train_sampler.__class__} with\nbatch_size:{self.args.train_batch_size}\nworld_size:{self.args.world_size}, dataloader_drop_last: {self.args.dataloader_drop_last}"
+            f"Sampler: {train_sampler.__class__} with\nbatch_size:{self.args.train_batch_size}\nworld_size:{self.parallel_context.get_local_rank(ParallelMode.DATA)}, dataloader_drop_last: {self.args.dataloader_drop_last}"
         )
         return DataLoader(
             train_dataset,
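
For orientation, here is a minimal, hedged sketch of how a data-parallel train dataloader can be wired from an oslo ParallelContext rather than from TrainingArguments. It is not the committed code: the helper name and arguments are hypothetical, and it assumes a `parallel_context` object exposing `get_world_size(ParallelMode.DATA)` and `get_local_rank(ParallelMode.DATA)`, the two accessors that appear elsewhere in this commit.

# Sketch only (not part of this commit): builds a sharded train DataLoader from a
# ParallelContext. `parallel_context`, `train_dataset`, `batch_size`, and
# `drop_last` are assumed to be supplied by the surrounding trainer.
from torch.utils.data import DataLoader, DistributedSampler

from oslo.torch.distributed.parallel_mode import ParallelMode


def build_train_dataloader(parallel_context, train_dataset, batch_size, drop_last):
    # Size of the data-parallel group and this process's position within it.
    dp_world_size = parallel_context.get_world_size(ParallelMode.DATA)
    dp_rank = parallel_context.get_local_rank(ParallelMode.DATA)

    sampler = None
    if dp_world_size > 1:
        # Shard the dataset across data-parallel replicas.
        sampler = DistributedSampler(
            train_dataset,
            num_replicas=dp_world_size,
            rank=dp_rank,
            drop_last=drop_last,
        )

    return DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=(sampler is None),
        drop_last=drop_last,
    )
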
32 changes: 14 additions & 18 deletions oslo/transformers/training_args.py
@@ -6,7 +6,6 @@
 import torch
 from transformers.trainer_utils import SchedulerType, IntervalStrategy
 from oslo.transformers.trainer_utils import OptimizerNames
-from oslo.torch.distributed.parallel_mode import ParallelMode


 @dataclass
@@ -294,29 +293,26 @@ def local_rank(self):
         """
         The local rank
         """
-        if self.parallel_context:
-            return self.parallel_context.get_local_rank(ParallelMode.DATA)
-        else:
-            return int(os.environ.get("LOCAL_RANK", -1))
+        return int(os.environ.get("LOCAL_RANK", -1))

-    @property
-    def world_size(self):
-        """
-        The number of processes used in parallel.
-        """
-        if self.parallel_context:
-            return self.parallel_context.get_world_size(ParallelMode.DATA)
-        elif self.local_rank != -1:
-            return int(os.environ["WORLD_SIZE"])
-        return 1
+    #
+    # @property
+    # def world_size(self):
+    #     """
+    #     The number of processes used in parallel.
+    #     """
+    #     if self.parallel_context:
+    #         return self.parallel_context.get_world_size(ParallelMode.DATA)
+    #     elif self.local_rank != -1:
+    #         return int(os.environ["WORLD_SIZE"])
+    #     return 1
+    #

     @property
     def process_index(self):
         """
         The index of the current process used.
         """
-        if self.local_rank != -1:
-            return self.local_rank
-        return torch.distributed.get_rank()
+        return 0

     @property
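
To make the removed logic easier to read outside the diff, the sketch below reproduces the environment-variable fallbacks of the old `local_rank`, `world_size`, and `process_index` properties as standalone helpers. The function names are hypothetical and nothing here is part of the commit itself.

# Illustrative helpers mirroring the removed TrainingArguments properties.
import os

import torch


def local_rank_from_env() -> int:
    # New behaviour of `local_rank`: rely solely on the launcher-set LOCAL_RANK.
    return int(os.environ.get("LOCAL_RANK", -1))


def world_size_from_env(local_rank: int) -> int:
    # Env fallback of the (now commented-out) `world_size` property:
    # distributed launchers set WORLD_SIZE; otherwise assume a single process.
    if local_rank != -1:
        return int(os.environ["WORLD_SIZE"])
    return 1


def process_index_from_env(local_rank: int) -> int:
    # Old `process_index` fallback: prefer the local rank, otherwise ask
    # torch.distributed for the global rank (requires an initialized process group).
    if local_rank != -1:
        return local_rank
    return torch.distributed.get_rank()
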
