From ae06bd2918425a8e8972333d041dedd2a824a3e3 Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Sat, 27 Apr 2024 00:59:35 +0800
Subject: [PATCH] [AutoParallel-PIR] add UT for Pipeline Parallelism (#63870)

---
 .../auto_parallel/static/engine.py            | 22 +++++++-----
 ...mi_auto_parallel_dist_to_static_mlp_pir.py | 36 +++++++++++++++++--
 ...t_semi_auto_parallel_dist_to_static_pir.py |  1 +
 3 files changed, 49 insertions(+), 10 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/static/engine.py b/python/paddle/distributed/auto_parallel/static/engine.py
index ae42a9afaf69a..8507010dd7fdc 100644
--- a/python/paddle/distributed/auto_parallel/static/engine.py
+++ b/python/paddle/distributed/auto_parallel/static/engine.py
@@ -1024,15 +1024,21 @@ def _init_comm(self):
         process_group.instantiate()
 
     def _init_lr(self):
-        buffer_tensor = global_scope().var("learning_rate_0").get_tensor()
-        if not isinstance(self._optimizer._learning_rate, float):
-            raise TypeError(
-                "learning rate should be float, got %s here"
-                % type(self._optimizer._learning_rate)
+        # hack to find learning_rate op
+        lr_name = None
+        for op in self.main_program.global_block().ops:
+            if (
+                op.name() == "pd_op.data"
+                and 'learning_rate' in op.attrs()["name"]
+            ):
+                lr_name = op.attrs()["name"]
+                break
+
+        if lr_name is not None:
+            buffer_tensor = global_scope().var(lr_name).get_tensor()
+            buffer_tensor.set(
+                np.float32(self._optimizer._learning_rate), self._place
             )
-        buffer_tensor.set(
-            np.float32(self._optimizer._learning_rate), self._place
-        )
 
     def _initialize(self, mode, init_parameters=True):
         self._place = _get_device()
diff --git a/test/auto_parallel/pir/semi_auto_parallel_dist_to_static_mlp_pir.py b/test/auto_parallel/pir/semi_auto_parallel_dist_to_static_mlp_pir.py
index 5a0a6d7077070..8a81def24ca4b 100644
--- a/test/auto_parallel/pir/semi_auto_parallel_dist_to_static_mlp_pir.py
+++ b/test/auto_parallel/pir/semi_auto_parallel_dist_to_static_mlp_pir.py
@@ -16,6 +16,7 @@
 import random
 
 import numpy as np
+from mlp_demo import PPDemoNet
 from test_to_static_pir_program import DemoNet
 
 import paddle
@@ -26,7 +27,6 @@
 
 BATCH_SIZE = 4
 BATCH_NUM = 4
-SEQ_LEN = 2
 IMAGE_SIZE = 16
 CLASS_NUM = 8
 
@@ -128,7 +128,6 @@ def run_dynamic(self, layer, opt, dist_loader, is_recompute=False):
             loss = loss_fn(out, label)
             loss_list.append(loss.numpy())
             loss.backward()
-            opt.step()
             opt.clear_grad()
         return np.array(loss_list)
 
@@ -157,11 +156,44 @@ def test_mp_demo_net(self):
         dy2static_losses, dist_model = self.run_dy2static(
             dy2static_layer, dy2static_opt, dist_dataloader
         )
+
         dy_losses = self.run_dynamic(dy_layer, dy_opt, dist_dataloader)
         np.testing.assert_array_equal(dy_losses, dy2static_losses)
 
+    def test_pp_demo_net(self):
+        paddle.disable_static()
+        self.set_random_seed(self._seed)
+        mesh1 = dist.ProcessMesh([0], dim_names=["x"])
+        mesh2 = dist.ProcessMesh([1], dim_names=["y"])
+        data_loader = self.create_data_loader()
+
+        self.set_random_seed(self._seed)
+        dy_layer = PPDemoNet(mesh1, mesh2)
+        dy_opt = paddle.optimizer.SGD(
+            learning_rate=0.1, parameters=dy_layer.parameters()
+        )
+
+        paddle.base.set_flags({'FLAGS_enable_pir_api': 1})
+        self.set_random_seed(self._seed)
+        dy2static_layer = PPDemoNet(mesh1, mesh2)
+        dy2static_opt = paddle.optimizer.SGD(
+            learning_rate=0.1, parameters=dy2static_layer.parameters()
+        )
+        dist_dataloader = dist.shard_dataloader(
+            dataloader=data_loader,
+            meshes=[mesh1, mesh2],
+        )
+        dy2static_losses, dist_model = self.run_dy2static(
+            dy2static_layer, dy2static_opt, dist_dataloader
+        )
+
+        dy_losses = self.run_dynamic(dy_layer, dy_opt, dist_dataloader)
+        if paddle.distributed.get_rank() == 1:
+            np.testing.assert_array_equal(dy_losses, dy2static_losses)
+
     def run_test_case(self):
         self.test_mp_demo_net()
+        self.test_pp_demo_net()
 
 
 if __name__ == '__main__':
diff --git a/test/auto_parallel/pir/test_semi_auto_parallel_dist_to_static_pir.py b/test/auto_parallel/pir/test_semi_auto_parallel_dist_to_static_pir.py
index eac31e5c851df..1647d0b1dd5a5 100644
--- a/test/auto_parallel/pir/test_semi_auto_parallel_dist_to_static_pir.py
+++ b/test/auto_parallel/pir/test_semi_auto_parallel_dist_to_static_pir.py
@@ -31,6 +31,7 @@ def test_mlp(self):
             {"dtype": "float32", "seed": "2023"}, {"backend": ["gpu"]}
         )
         for envs in envs_list:
+            # self._log_dir.name = "./log"
            ckpt_path_tmp = tempfile.TemporaryDirectory()
             envs["ckpt_path"] = ckpt_path_tmp.name
             self.run_test_case(
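
Note on the new pipeline-parallel test: it imports PPDemoNet from mlp_demo, which is not part of this diff. The following is a minimal sketch of what a two-stage layer of that shape typically looks like with the semi-auto API, assuming the sizes used above (IMAGE_SIZE=16, CLASS_NUM=8); the class name, parameter placements, and layer sizes here are assumptions, and the real mlp_demo.PPDemoNet may differ.

# Sketch only: hypothetical two-stage pipeline-parallel layer; the actual
# PPDemoNet used by the test lives in mlp_demo.py and is not shown here.
import paddle
import paddle.distributed as dist

IMAGE_SIZE = 16
CLASS_NUM = 8


class PPDemoNetSketch(paddle.nn.Layer):
    def __init__(self, mesh1, mesh2):
        super().__init__()
        self.mesh2 = mesh2
        # stage 0 parameters live on mesh1 (rank 0), stage 1 on mesh2 (rank 1)
        self.linear_0 = paddle.nn.Linear(IMAGE_SIZE, IMAGE_SIZE)
        self.linear_1 = paddle.nn.Linear(IMAGE_SIZE, CLASS_NUM)
        self.linear_0.weight = dist.shard_tensor(
            self.linear_0.weight, mesh1, [dist.Replicate()]
        )
        self.linear_1.weight = dist.shard_tensor(
            self.linear_1.weight, mesh2, [dist.Replicate()]
        )

    def forward(self, x):
        out = self.linear_0(x)
        # hand the activation over to the second pipeline stage
        out = dist.reshard(out, self.mesh2, [dist.Replicate()])
        return self.linear_1(out)

Because the two stages live on different ranks, the dy2static losses are only materialized on the last stage, which is why the test above compares losses only when paddle.distributed.get_rank() == 1.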