pre-commit
tree-park committed Nov 27, 2022
1 parent db85685 commit 316daf0
Showing 1 changed file with 73 additions and 57 deletions.
130 changes: 73 additions & 57 deletions oslo/transformers/oslo_init.py
@@ -12,7 +12,8 @@
from oslo.torch.nn.parallel.sequence_parallel import SequenceParallel
from oslo.torch.nn.parallel.data_parallel.data_parallel import DataParallel
from oslo.torch.nn.parallel.data_parallel._ddp.distributed_data_parallel import (
DistributedDataParallel,)
DistributedDataParallel,
)
from oslo.torch.distributed.parallel_mode import ParallelMode
from .trainer_utils import log_dist

@@ -28,10 +29,8 @@ def _type(_dtype):

def _values(*args):
return lambda key, val: {
"check":
val in args,
"msg":
f"{key}: {val} is not a valid set. it must be one of {list(args)}",
"check": val in args,
"msg": f"{key}: {val} is not a valid set. it must be one of {list(args)}",
}
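
Note on the hunk above: _type and _values are small validator factories. Each returns a closure that _config_check later calls with a key and the user-supplied value, and the closure reports a "check" flag plus an error message. A minimal sketch of the pattern follows; the body of _type is not shown in this diff, so the isinstance check and the example values below are assumptions for illustration only.

def _type(_dtype):
    # Assumed behaviour: accept a value only if it has the expected type.
    return lambda key, val: {
        "check": isinstance(val, _dtype),
        "msg": f"{key}: {val} is not a valid type. it must be {_dtype}",
    }

def _values(*args):
    # Shown in the hunk above: accept a value only if it is one of *args.
    return lambda key, val: {
        "check": val in args,
        "msg": f"{key}: {val} is not a valid set. it must be one of {list(args)}",
    }

# Hypothetical usage, mirroring how _config_check invokes a schema entry:
validator = _values("1d", "2d", "2.5d", "3d")
result = validator("parallel_mode", "4d")
assert result["check"] is False  # "4d" is rejected
print(result["msg"])  # lists the allowed values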


@@ -50,11 +49,7 @@ class SupportedBackend(Enum):


SUPPORTED_FEATURES = {
"backend": {
"name": str,
"host": str,
"port": str
},
"backend": {"name": str, "host": str, "port": str},
"mixed_precision": {
"enable": _type(bool),
},
@@ -115,17 +110,21 @@ def _config_check(arg, user_config):
if isinstance(arg, dict):
assert k in arg, (
f"An argument ``{k}`` is not available. "
f"We only support the arguments like {list(arg.keys())}.")
f"We only support the arguments like {list(arg.keys())}."
)
else:
raise Exception(f"``{k}: {user_config[k]} is not a valid set. "
f"please check your configuration.``")
raise Exception(
f"``{k}: {user_config[k]} is not a valid set. "
f"please check your configuration.``"
)

if isinstance(user_config[k], dict):
_config_check(arg[k], user_config[k])
else:
assert not isinstance(arg[k], dict), (
f"``{k}: {user_config[k]} is not a valid set. "
f"please check your configuration.``")
f"please check your configuration.``"
)
check_result = arg[k](k, user_config[k])
assert check_result["check"], check_result["msg"]
else:
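
Note on the hunk above: _config_check walks the user configuration against SUPPORTED_FEATURES recursively. Unknown keys fail the first assert, nested dicts recurse into the matching sub-schema, and leaf values are handed to the validator callables produced by _type and _values. The following is a simplified, hypothetical restatement of that walk; it is not the OSLO implementation and omits the remaining branches of the real function.

def check_against_schema(schema: dict, user_config: dict) -> None:
    for key, value in user_config.items():
        # Reject keys that the schema does not know about.
        assert key in schema, (
            f"An argument ``{key}`` is not available. "
            f"We only support the arguments like {list(schema.keys())}."
        )
        if isinstance(value, dict):
            # Nested section: descend into the matching sub-schema.
            check_against_schema(schema[key], value)
        else:
            # Leaf: schema[key] is a validator returned by _type/_values.
            result = schema[key](key, value)
            assert result["check"], result["msg"]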
@@ -222,104 +221,111 @@ def __init__(self, config_file_or_dict):
with open(config_file_or_dict, "r", encoding="utf-8") as f:
cfg = json.load(f)
else:
raise ValueError(
"Expecting either a path to a oslo config file or a dict")
raise ValueError("Expecting either a path to a oslo config file or a dict")
_config_check(SUPPORTED_FEATURES, cfg)

log_dist("*** OSLO CONFIG ***")

if 'backend' not in cfg:
if "backend" not in cfg:
self.backend = SupportedBackend.TORCH
elif cfg['backend'] in SupportedBackend:
self.backend = SupportedBackend[cfg['backend']]
elif cfg["backend"] in SupportedBackend:
self.backend = SupportedBackend[cfg["backend"]]
if self.backend in [SupportedBackend.OPENMPI]:
if 'host' in cfg['backend']:
self.host = cfg['backend']['host']
if "host" in cfg["backend"]:
self.host = cfg["backend"]["host"]
log_dist(f"host: {self.host}")
else:
log_dist(f"host is required to use {self.backend}")
if 'port' in cfg['backend']:
self.port = cfg['backend']['port']
if "port" in cfg["backend"]:
self.port = cfg["backend"]["port"]
log_dist(f"port: {self.port}")
else:
raise ValueError(f"port is required to use {self.backend}")
log_dist(f"backend engine: {self.backend}")

if 'mixed_precision' in cfg and cfg['mixed_precision']['enable'] is True:
if "mixed_precision" in cfg and cfg["mixed_precision"]["enable"] is True:
self.mixed_precision = True
log_dist("mixed_precision: enabled")

if 'data_parallelism' in cfg and cfg['data_parallelism']['enable'] is True:
if cfg['data_parallelism']["parallel_size"] is None:
if "data_parallelism" in cfg and cfg["data_parallelism"]["enable"] is True:
if cfg["data_parallelism"]["parallel_size"] is None:
log_dist(
"data_parallelism can not be usable because parallel_size is required.",
logging.WARNING,
)
elif cfg['data_parallelism']["zero_stage"] is None:
elif cfg["data_parallelism"]["zero_stage"] is None:
logging.warning(
"data_parallelism can not be usable because zero_stage is required."
)
else:
if ('params' in cfg['data_parallelism'] and
cfg['data_parallelism']['params']['cpu_offload']):
if (
"params" in cfg["data_parallelism"]
and cfg["data_parallelism"]["params"]["cpu_offload"]
):
self.cpu_offload = True
self.data_parallelism = cfg['data_parallelism']
self.data_parallelism = cfg["data_parallelism"]
log_dist(
f"data_parallelism: enabled"
f"\tparallel_size: {self.data_parallelism['parallel_size']}"
f"\tzero_stage: {self.data_parallelism['zero_stage']}"
f"\tcpu_offload: {self.cpu_offload}"
)

if 'sequence_parallelism' in cfg and cfg['sequence_parallelism']['enable'] is True:
if cfg['sequence_parallelism']["parallel_size"] is None:
if (
"sequence_parallelism" in cfg
and cfg["sequence_parallelism"]["enable"] is True
):
if cfg["sequence_parallelism"]["parallel_size"] is None:
log_dist(
"sequence_parallelism can not be usable because parallel_size is required.",
logging.WARNING,
)
else:
self.sequence_parallelism = cfg['sequence_parallelism']
self.sequence_parallelism = cfg["sequence_parallelism"]
log_dist(
f"sequence_parallelism: enabled\n\tparallel_size: {self.sequence_parallelism['parallel_size']}"
)

if 'tensor_parallelism' in cfg and cfg['tensor_parallelism']['enable'] is True:
if cfg['tensor_parallelism']["parallel_size"] is None:
if "tensor_parallelism" in cfg and cfg["tensor_parallelism"]["enable"] is True:
if cfg["tensor_parallelism"]["parallel_size"] is None:
raise ValueError(
"tensor_parallelism can not be usable because parallel_size is required."
)
elif cfg['tensor_parallelism']["parallel_mode"] is None:
elif cfg["tensor_parallelism"]["parallel_mode"] is None:
log_dist(
"tensor_parallelism can not be usable because parallel_mode is required.",
logging.WARNING,
)
else:
self.tensor_parallelism = cfg['tensor_parallelism']
self.tensor_parallelism = cfg["tensor_parallelism"]
log_dist(
f"tensor_parallelism: enabled\n\tparallel_size: {self.tensor_parallelism['parallel_size']}\n\tparallel_mode: {self.tensor_parallelism['parallel_mode']}"
)

if 'pipeline_parallelism' in cfg and cfg['pipeline_parallelism']['enable'] is True:
if cfg['pipeline_parallelism']["parallel_size"] is None:
if (
"pipeline_parallelism" in cfg
and cfg["pipeline_parallelism"]["enable"] is True
):
if cfg["pipeline_parallelism"]["parallel_size"] is None:
log_dist(
"pipeline_parallelism can not be usable because parallel_size is required.",
logging.WARNING,
)
self.pipeline_parallelism = None
else:
self.pipeline_parallelism = cfg['pipeline_parallelism']
self.pipeline_parallelism = cfg["pipeline_parallelism"]
log_dist(
f"pipeline_parallelism: enabled\n\tparallel_size: {self.pipeline_parallelism['parallel_size']}"
)

if 'expert_parallelism' in cfg and cfg['expert_parallelism']['enable'] is True:
if cfg['expert_parallelism']["parallel_size"] is None:
if "expert_parallelism" in cfg and cfg["expert_parallelism"]["enable"] is True:
if cfg["expert_parallelism"]["parallel_size"] is None:
log_dist(
"expert_parallelism can not be usable because parallel_size is required.",
logging.WARNING,
)
else:
self.expert_parallelism = cfg['expert_parallelism']
self.expert_parallelism = cfg["expert_parallelism"]
log_dist(
f"expert_parallelism: enabled\n\tparallel_size: {self.expert_parallelism['parallel_size']}"
)
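
Note on the __init__ above: taken together, the parsed keys imply a configuration shaped roughly like the dict below. This is an illustrative sketch assembled from the keys read in this diff; the concrete values are placeholders, and the authoritative set of accepted options is SUPPORTED_FEATURES.

# Illustrative only: a config dict with the keys that __init__ reads.
# Values are placeholders, not recommendations.
oslo_config = {
    "backend": {"name": "torch"},
    "mixed_precision": {"enable": True},
    "data_parallelism": {
        "enable": True,
        "parallel_size": 2,
        "zero_stage": 1,
        "params": {"cpu_offload": False},
    },
    "tensor_parallelism": {
        "enable": True,
        "parallel_size": 2,
        "parallel_mode": "1d",
    },
    "pipeline_parallelism": {"enable": False},
    "expert_parallelism": {"enable": False},
}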
@@ -342,7 +348,8 @@ def __repr__(self):


def init_oslo_features(
oslo_init_config: OsloTrainerConfig,) -> Tuple[ParallelContext, List]:
oslo_init_config: OsloTrainerConfig,
) -> Tuple[ParallelContext, List]:
"""
Init OSLO features with json or dict configuration user passed.
ParallelContext or other effective features should be defined on this function
@@ -356,26 +363,35 @@ def init_oslo_features(
>> allocate_params(wrapper_model, parallel_context)
"""
cfg = oslo_init_config
data_parallel_size = (cfg.data_parallelism['parallel_size']
if cfg.data_parallelism else 1)
sequence_parallel_size = (cfg.sequence_parallelism['parallel_size']
if cfg.sequence_parallelism else 1)
expert_parallel_size = (cfg.expert_parallelism['parallel_size']
if cfg.expert_parallelism else 1)
pipeline_parallel_size = (cfg.pipeline_parallelism['parallel_size']
if cfg.pipeline_parallelism else 1)
data_parallel_size = (
cfg.data_parallelism["parallel_size"] if cfg.data_parallelism else 1
)
sequence_parallel_size = (
cfg.sequence_parallelism["parallel_size"] if cfg.sequence_parallelism else 1
)
expert_parallel_size = (
cfg.expert_parallelism["parallel_size"] if cfg.expert_parallelism else 1
)
pipeline_parallel_size = (
cfg.pipeline_parallelism["parallel_size"] if cfg.pipeline_parallelism else 1
)
tensor_parallel_size, tensor_parallel_depth, tensor_parallel_mode = (
1,
1,
TENSOR_PARALLEL_MAPPING["1d"],
)
if cfg.tensor_parallelism:
tensor_parallel_size = cfg.tensor_parallelism['parallel_size']
tensor_parallel_size = cfg.tensor_parallelism["parallel_size"]
tensor_parallel_mode = TENSOR_PARALLEL_MAPPING[
cfg.tensor_parallelism['parallel_mode']]
if 'param' in cfg.tensor_parallelism and 'parallel_depth_2.5d' in cfg.tensor_parallelism['param']:
tensor_parallel_depth = cfg.tensor_parallelism['param'][
"parallel_depth_2.5d"]
cfg.tensor_parallelism["parallel_mode"]
]
if (
"param" in cfg.tensor_parallelism
and "parallel_depth_2.5d" in cfg.tensor_parallelism["param"]
):
tensor_parallel_depth = cfg.tensor_parallelism["param"][
"parallel_depth_2.5d"
]

if cfg.backend == SupportedBackend.TORCH:
parallel_context = ParallelContext.from_torch(
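
For context, the docstring of init_oslo_features (partly collapsed in this diff) sketches the intended call sequence. A rough usage outline follows; only allocate_params(wrapper_model, parallel_context) comes verbatim from the docstring, while the file name and variable names are assumed for illustration.

# Hypothetical usage sketch, not taken verbatim from the repository.
oslo_init_config = OsloTrainerConfig("oslo_config.json")  # a path or a dict
parallel_context, model_wrappers = init_oslo_features(oslo_init_config)
# The returned wrappers (e.g. the DataParallel / SequenceParallel classes
# imported at the top of this file) are applied to the model, and finally:
# allocate_params(wrapper_model, parallel_context)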
