add continue train docs and check applied device (alibaba#29)
Co-authored-by: xianyan.xianyanjia <[email protected]>
SeaOfOcean authored Nov 23, 2023
1 parent c2c3098 commit 2de8eb5
Showing 10 changed files with 92 additions and 4 deletions.
4 changes: 2 additions & 2 deletions chatlearn/models/rlhf_module.py
@@ -306,7 +306,7 @@ def after_episode(self):
Operations after one episode.
"""

def build_dataset(self, train_prompts):
def build_dataset(self, train_prompts, is_eval=False):
"""
Build prompt dataset
@@ -336,7 +336,7 @@ def _build_dataloader(self, data, batch_size, sample_per_episode_per_replica=-1,
assert sample_per_episode_per_replica > 0, \
"The dataloader for training expect positive sample_per_episode_per_replica, "\
f"but got {sample_per_episode_per_replica}"
dataset = self.build_dataset(data) # pylint: disable=assignment-from-no-return
dataset = self.build_dataset(data, is_eval) # pylint: disable=assignment-from-no-return
assert hasattr(dataset, 'collate_fn'), \
f"{dataset.__class__.__name__} has no attribute `collate_fn`. If you would like "\
"to use the default collate_fn to batch samples, try adding `self.collate_fn = None` "\
25 changes: 25 additions & 0 deletions chatlearn/schedule/model_manager.py
@@ -51,6 +51,22 @@ def __init__(self, rlhf_models, resouce_manager, global_args):
self._parameter_sync_model_mapping = {}
self.model_packs = []

def _get_total_device_required(self):
total_device = 0
remote_states = set()
for group in self.rlhf_args.colocation:
colocate_models = [self._name2distmodel[name] for name in group]
max_device = max(m.total_device for m in colocate_models)
total_device += max_device
for name in group:
remote_states.add(name)
for model in self.dist_models:
# place non-colocate models
if model.name not in remote_states:
max_device = model.total_device
total_device += max_device
return total_device

def remote(self) -> list:
"""
convert model to remote
@@ -65,6 +81,15 @@ def remote(self) -> list:
self.dist_models.append(dist_model)
self._name2distmodel[model.name] = dist_model

total_device_required = self._get_total_device_required()
if total_device_required > self.resouce_manager.total_gpu:
raise RuntimeError(f"The number of required gpus for current job is {total_device_required}, " + \
f"while the number of applied gpus is {self.resouce_manager.total_gpu}")
if self.resouce_manager.total_gpu > total_device_required:
logger.warning(f"The number of applied gpus is {self.resouce_manager.total_gpu}, " + \
f"while the number of required gpus is {total_device_required}, " + \
f"there is {self.resouce_manager.total_gpu - total_device_required} wasted gpus")

for group in self.rlhf_args.colocation:
colocate_models = [self._name2distmodel[name] for name in group]
self.place_models_to_remote_devices(colocate_models)
3 changes: 3 additions & 0 deletions chatlearn/schedule/resource_manager.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""resource manager"""

import os
import time

import ray
@@ -36,6 +37,8 @@ def __init__(self, models):
resource = ray.nodes()[0]['Resources']
self.gpu_per_node = int(resource['GPU'])
self.cpu_per_node = int(resource['CPU'])
self.nnode = int(os.environ['WORLD_SIZE'])
self.total_gpu = self.nnode * self.gpu_per_node

def get_placement_group_state(self, pg):
try:
6 changes: 6 additions & 0 deletions chatlearn/utils/arguments.py
@@ -469,6 +469,12 @@ def _validate_params(self):
if model_args.batch_generation.min_prompt_length:
logger.info(f"Enable batch generation: \
min_prompt_length = {model_args.batch_generation.min_prompt_length}")
if self.rlhf_args.colocation and len(self.rlhf_args.colocation) > 0:
model_set = set()
for colocate_models in self.rlhf_args.colocation:
for model_name in colocate_models:
assert model_name not in model_set, f"Model {model_name} should only appear once in colocation group"
model_set.add(model_name)
logger.info(f"Env Config: \n{self.env_args}")
logger.info(f"RLHF Config: \n{self.rlhf_args}")
for name, model_args in self.models.items():
Binary file added docs/images/fault.png
1 change: 1 addition & 0 deletions docs/zh/index.rst
@@ -26,6 +26,7 @@ ChatLearn Documentation

tutorial/tutorial_llama2
tutorial/tutorial_bloom
tutorial/continue_train

|
|
52 changes: 52 additions & 0 deletions docs/zh/tutorial/continue_train.md
@@ -0,0 +1,52 @@
# Resuming Training and Fault Tolerance

An RLHF job involves computation and interaction across multiple models. As model sizes and the amount of compute grow, occasional failures in the underlying software stack or hardware environment can bring the job to a halt.
To ensure that an interrupted job can restore its state and resume automatically, ChatLearn provides a resume capability; combined with AIMaster in PAI-DLC, it supports automatic error detection and automatic resumption.

## Configuring resumption for ChatLearn RLHF

Resuming an RLHF job involves the following aspects:
1. Recording and restoring the data progress. To record the data state, configure `data_checkpoint_path` in the main training configuration file, e.g. `rlhf.yaml`.
If `data_checkpoint_path` is not empty, ChatLearn records the current data progress and stores a data checkpoint on every `save_checkpoint`.
2. Restoring the training state, such as the episode and iteration counters. When `data_checkpoint_path` is configured and the folder contains a matching data checkpoint, ChatLearn automatically restores the training state to the latest checkpoint
and sets the model's `resume_training` attribute to `True`.
3. Loading checkpoints. When `resume_training == True`, the checkpoints loaded by the `reference` and `reward` models stay unchanged, while
`ppo_policy` and `ppo_value` must load the checkpoints saved during training rather than the original initialization checkpoints, which requires special handling in the `setup` stage:

```python
if self.resume_training:
    # load the checkpoints saved during training instead of the original initialization checkpoints
    self.args.load = get_args().save
    self.args.load_iteration = -1
    # also restore the optimizer, RNG, argument and scheduler states from the checkpoint
    self.args.no_load_optim = False
    self.args.no_load_rng = False
    self.args.no_load_args = False
    self.args.no_load_scheduler = False
    self.args.finetune = False
```

For more details, refer to `examples/megatron/step3_rlhf/run_scripts/llama2/run_7b_7b.sh`.

If you have configured `data_checkpoint_path` but do not want resumption, the feature can be turned off by setting `enable_resume_training: False`.
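
As a reference, here is a minimal sketch of the relevant entries. The exact layout of `rlhf.yaml` varies between examples, so the section name below is an assumption and the path is only a placeholder; adapt both to your configuration.

```yaml
rlhf:
  # where ChatLearn records data progress; a data checkpoint is written on every save_checkpoint
  data_checkpoint_path: /path/to/save_model/my_exp/data_checkpoint
  # set to False to keep writing data checkpoints without automatically resuming
  enable_resume_training: True
```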

## Fault tolerance and automatic resumption with DLC AIMaster

DLC provides AIMaster-based fault-tolerance monitoring. AIMaster is a job-level component: when fault-tolerance monitoring is enabled for a job,
an AIMaster instance is launched alongside the other job instances to monitor the job, make fault-tolerance decisions, and control resources.

You can combine AIMaster's fault tolerance with ChatLearn's resume capability to restart training jobs automatically.

The following is an example fault-tolerance monitoring configuration. It enables hang detection and error detection:
if the job hangs for more than 1 hour, or AIMaster detects an error, the job is restarted automatically, at most 3 times.

![aimaster](../../images/fault.png)
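
In case the screenshot is not visible, the sketch below only restates those settings. The key names are hypothetical and are not actual AIMaster options; the real values are configured in the DLC console, as described in the documentation linked below.

```yaml
# Hypothetical illustration only: these are not real AIMaster configuration keys.
fault_tolerance:
  hang_detection: enabled   # restart the job if it hangs
  hang_threshold: 1h        # a hang longer than 1 hour triggers a restart
  error_detection: enabled  # restart when AIMaster detects an error
  max_restarts: 3           # at most 3 automatic restarts
```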

For more fault-tolerance options, see the DLC [fault-tolerance documentation](https://help.aliyun.com/zh/pai/user-guide/fault-tolerance-monitoring-based-on-aimaster?spm=a2c4g.11186623.0.0.12011976WAncyo).









2 changes: 1 addition & 1 deletion examples/megatron/models/old_policy_inference.py
@@ -82,7 +82,7 @@ def setup(self):
get_args().latest_entropies = []
return 'ok'

def build_dataset(self, train_prompts):
def build_dataset(self, train_prompts, is_eval=False):
'''
framework source: dataset = self.build_dataset(data)
:param train_prompts: all train prompts used in this training run??
2 changes: 1 addition & 1 deletion examples/megatron/models/vllm_policy_inference.py
@@ -46,7 +46,7 @@ def setup(self):

return 'ok'

def build_dataset(self, train_prompts):
def build_dataset(self, train_prompts, is_eval=False):
'''
framework source: dataset = self.build_dataset(data)
:param train_prompts: all train prompts used in this training run??
1 change: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ fi
LOG_DIR=${OUTPUT_DIR}/logs/${exp_name}
TENSORBOARD_DIR=${OUTPUT_DIR}/tensorboard/${exp_name}
SAVE_DIR=${OUTPUT_DIR}/save_model/${exp_name}
export data_checkpoint_path=${OUTPUT_DIR}/save_model/${exp_name}/data_checkpoint

mkdir -p ${LOG_DIR}

