Add save and eval features on trainer (#132)
## Title

- Add save and eval features on the trainer

## Description

- Add evaluation; tests are done
- Add save; tests are done
- Modify oslo_init to raise ValueError when a required config value is empty (e.g. parallel_mode); see the sketch below
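For illustration, a minimal sketch (not from this PR) of a config fragment that the stricter oslo_init now rejects. The key names come from the diff below; the surrounding constructor is not shown in this page, so it is referenced only in a comment:

```python
# Hypothetical illustration: after this change, a required field left
# empty (None), such as "parallel_mode", raises ValueError instead of
# only logging a warning and silently continuing.
cfg = {
    "backend": {"host": "127.0.0.1", "port": 29500},
    "tensor_parallelism": {
        "enable": True,
        "parallel_size": 4,
        "parallel_mode": None,  # required -> now raises ValueError
    },
}
# Passing this to oslo_init's config constructor (class name not shown in
# this diff) now raises:
#   ValueError: tensor_parallelism can not be usable because parallel_mode is required.
```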
tree-park authored Feb 6, 2023
1 parent 9fc5067 commit 7db1d1b
Showing 6 changed files with 867 additions and 58 deletions.
46 changes: 25 additions & 21 deletions oslo/transformers/oslo_init.py
@@ -229,12 +229,12 @@ def __init__(self, config_file_or_dict):
             self.host = cfg["backend"]["host"]
             log_dist(f"host: {self.host}")
         else:
-            log_dist(f"host is required to use {self.backend}")
+            raise ValueError(f"host is required to use {self.backend}")
         if "port" in cfg["backend"]:
             self.port = cfg["backend"]["port"]
             log_dist(f"host: {self.host}")
         else:
-            ValueError(f"post is required to use {self.backend}")
+            raise ValueError(f"post is required to use {self.backend}")
         log_dist(f"backend engine: {self.backend}")
 
         if "mixed_precision" in cfg and cfg["mixed_precision"]["enable"] is True:
@@ -243,13 +243,12 @@ def __init__(self, config_file_or_dict):
 
         if "data_parallelism" in cfg and cfg["data_parallelism"]["enable"] is True:
             if cfg["data_parallelism"]["parallel_size"] is None:
-                log_dist(
-                    "data_parallelism can not be usable because parallel_size is required.",
-                    logging.WARNING,
+                raise ValueError(
+                    f"data_parallelism can not be usable because parallel_size is required."
                 )
             elif cfg["data_parallelism"]["zero_stage"] is None:
-                logging.warning(
-                    "data_parallelism can not be usable because zero_stage is required."
+                raise ValueError(
+                    f"data_parallelism can not be usable because zero_stage is required."
                 )
             else:
                 if (
@@ -258,6 +257,8 @@ def __init__(self, config_file_or_dict):
             ):
                 self.cpu_offload = True
             self.data_parallelism = cfg["data_parallelism"]
+            if "params" not in self.data_parallelism:
+                self.data_parallelism["params"] = {}
             log_dist(
                 f"data_parallelism: enabled"
                 f"\tparallel_size: {self.data_parallelism['parallel_size']}"
@@ -270,28 +271,30 @@ def __init__(self, config_file_or_dict):
             and cfg["sequence_parallelism"]["enable"] is True
         ):
             if cfg["sequence_parallelism"]["parallel_size"] is None:
-                log_dist(
-                    "sequence_parallelism can not be usable because parallel_size is required.",
-                    logging.WARNING,
+                raise ValueError(
+                    f"sequence_parallelism can not be usable because parallel_size is required."
                 )
             else:
                 self.sequence_parallelism = cfg["sequence_parallelism"]
+                if "params" not in self.sequence_parallelism:
+                    self.sequence_parallelism["params"] = {}
                 log_dist(
                     f"sequence_parallelism: enabled\n\tparallel_size: {self.sequence_parallelism['parallel_size']}"
                 )
 
         if "tensor_parallelism" in cfg and cfg["tensor_parallelism"]["enable"] is True:
             if cfg["tensor_parallelism"]["parallel_size"] is None:
-                ValueError(
+                raise ValueError(
                     "tensor_parallelism can not be usable because parallel_size is required."
                 )
             elif cfg["tensor_parallelism"]["parallel_mode"] is None:
-                log_dist(
-                    "tensor_parallelism can not be usable because parallel_mode is required.",
-                    logging.WARNING,
+                raise ValueError(
+                    "tensor_parallelism can not be usable because parallel_mode is required."
                 )
             else:
                 self.tensor_parallelism = cfg["tensor_parallelism"]
+                if "params" not in self.tensor_parallelism:
+                    self.tensor_parallelism["params"] = {}
                 log_dist(
                     f"tensor_parallelism: enabled\n\tparallel_size: {self.tensor_parallelism['parallel_size']}\n\tparallel_mode: {self.tensor_parallelism['parallel_mode']}"
                 )
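For reference, a config fragment that satisfies both tensor_parallelism checks in this hunk. The mode string is a hypothetical example; the diff only shows that `parallel_mode` must be non-None:

```python
# Passes both checks in this hunk: parallel_size and parallel_mode are set.
tensor_parallelism = {
    "enable": True,
    "parallel_size": 2,
    "parallel_mode": "1d",  # hypothetical value; the check only requires non-None
}
```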
@@ -301,25 +304,26 @@ def __init__(self, config_file_or_dict):
             and cfg["pipeline_parallelism"]["enable"] is True
         ):
             if cfg["pipeline_parallelism"]["parallel_size"] is None:
-                log_dist(
-                    "pipeline_parallelism can not be usable because parallel_size is required.",
-                    logging.WARNING,
+                raise ValueError(
+                    "pipeline_parallelism can not be usable because parallel_size is required."
                 )
-                self.pipeline_parallelism = None
             else:
                 self.pipeline_parallelism = cfg["pipeline_parallelism"]
+                if "params" not in self.pipeline_parallelism:
+                    self.pipeline_parallelism["params"] = {}
                 log_dist(
                     f"pipeline_parallelism: enabled\n\tparallel_size: {self.pipeline_parallelism['parallel_size']}"
                 )
 
         if "expert_parallelism" in cfg and cfg["expert_parallelism"]["enable"] is True:
             if cfg["expert_parallelism"]["parallel_size"] is None:
-                log_dist(
-                    "expert_parallelism can not be usable because parallel_size is required.",
-                    logging.WARNING,
+                raise ValueError(
+                    "expert_parallelism can not be usable because parallel_size is required."
                 )
             else:
                 self.expert_parallelism = cfg["expert_parallelism"]
+                if "params" not in self.expert_parallelism:
+                    self.expert_parallelism["params"] = {}
                 log_dist(
                     f"expert_parallelism: enabled\n\tparallel_size: {self.expert_parallelism['parallel_size']}"
                 )
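A minimal sketch of how the new fail-fast behavior could be exercised in a test. Since this page does not show the config class itself, the check from the hunk above is re-implemented standalone here rather than imported from OSLO:

```python
import pytest

def validate_expert_parallelism(cfg: dict) -> None:
    # Standalone re-implementation of the check in this hunk, not OSLO code.
    ep = cfg.get("expert_parallelism")
    if ep and ep["enable"] is True:
        if ep["parallel_size"] is None:
            raise ValueError(
                "expert_parallelism can not be usable because parallel_size is required."
            )

def test_missing_parallel_size_raises():
    with pytest.raises(ValueError):
        validate_expert_parallelism(
            {"expert_parallelism": {"enable": True, "parallel_size": None}}
        )
```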