diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_util.h b/paddle/fluid/framework/new_executor/instruction/instruction_util.h
index 2887d3c4aca2f..b1977721504b2 100644
--- a/paddle/fluid/framework/new_executor/instruction/instruction_util.h
+++ b/paddle/fluid/framework/new_executor/instruction/instruction_util.h
@@ -68,5 +68,8 @@
 void HandleForInplaceOp(pir::Operation* op, InstructionBase* instr);
 void ShareVarBuffer(const Variable* src_var, Variable* dst_var);
+
+std::unordered_set<pir::Value> GetInternalInputs(pir::Block* block);
+std::unordered_set<pir::Value> GetInternalOutputs(pir::Block* block);
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_pylayer_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_pylayer_op.cc
index 6c3667114cbbe..e3ebe5006db82 100644
--- a/paddle/fluid/pir/dialect/operator/ir/manual_pylayer_op.cc
+++ b/paddle/fluid/pir/dialect/operator/ir/manual_pylayer_op.cc
@@ -20,6 +20,7 @@ paddle::dialect::PyLayerOp
 
 #include "paddle/fluid/pir/dialect/operator/ir/manual_pylayer_op.h"
 
+#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h"
 #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h"
 #include "paddle/fluid/pir/dialect/operator/ir/api_builder.h"
 #include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
@@ -196,6 +197,119 @@ void PyLayerOp::UpdateOutput() {
   VerifyRegion();
 }
 
+void PyLayerOp::UpdateInputOutput() {
+  PADDLE_ENFORCE_NOT_NULL(*this,
+                          common::errors::InvalidArgument(
+                              "The pylayer_op in PyLayerOp used to update "
+                              "the output can't be nullptr"));
+  auto program_block = parent();
+  PADDLE_ENFORCE_NOT_NULL(
+      program_block,
+      common::errors::InvalidArgument(
+          "The parent block of the pylayer_op which is used to update "
+          "the output can't be nullptr"));
+
+  std::unordered_set<pir::Value> global_block_inner_inputs;
+  global_block_inner_inputs =
+      paddle::framework::GetInternalInputs(program_block);
+
+  pir::Block &block = forward_block();
+  std::vector<pir::Value> input_values = inputs();
+  std::vector<pir::Value> output_values = outputs();
+
+  std::unordered_set<pir::Value> inner_inputs;
+  inner_inputs = paddle::framework::GetInternalInputs(&block);
+  std::unordered_set<pir::Value> inner_outputs;
+  inner_outputs = paddle::framework::GetInternalOutputs(&block);
+
+  for (size_t arg_id = 0; arg_id < block.args_size();) {
+    if (block.arg(arg_id) && (!inner_inputs.count(block.arg(arg_id)))) {
+      block.EraseArg(arg_id);
+      continue;
+    }
+    ++arg_id;
+  }
+
+  bool need_build_new_pylayer = false;
+  std::vector<pir::Type> new_pylayer_output_types;
+  std::vector<pir::Value> new_pylayer_inputs;
+  std::vector<pir::Value> new_pylayer_yield_inputs;
+
+  for (auto value : input_values) {
+    if (value && (!inner_inputs.count(value))) {
+      need_build_new_pylayer = true;
+      continue;
+    }
+    new_pylayer_inputs.push_back(value);
+  }
+
+  std::vector<int> old_pylayer_outputs_map_to_new_pylayer_outputs_index;
+
+  if (block.back().isa<pir::YieldOp>()) {
+    std::vector<pir::Value> yield_inputs = block.back().operands_source();
+    PADDLE_ENFORCE_EQ(
+        yield_inputs.size(),
+        output_values.size(),
+        common::errors::Unimplemented(
+            "YieldOp's input size (%d) must be equal to PyLayer's output "
+            "size (%d). If a pass modifies PyLayer's block, it should not "
+            "modify the YieldOp, because the YieldOp must be updated "
+            "together with PyLayer's outputs. Otherwise, when updating "
+            "PyLayer's outputs, the mapping between the new and the old "
+            "PyLayer outputs cannot be known, so ReplaceAllUsesWith can't "
+            "be used to update the Value of PyLayer outputs.",
+            yield_inputs.size(),
+            output_values.size()));
+    int index = 0;
+    for (size_t i = 0; i < yield_inputs.size(); i++) {
+      if (yield_inputs[i] && (!inner_outputs.count(yield_inputs[i]))) {
+        PADDLE_ENFORCE_EQ(
+            global_block_inner_inputs.count(output_values[i]),
+            0,
+            common::errors::Unimplemented(
+                "The PyLayer's output is not defined in PyLayer's block, "
+                "but is used in the global block."));
+        need_build_new_pylayer = true;
+        old_pylayer_outputs_map_to_new_pylayer_outputs_index.push_back(-1);
+        continue;
+      }
+      new_pylayer_output_types.push_back(yield_inputs[i].type());
+      new_pylayer_yield_inputs.push_back(yield_inputs[i]);
+      old_pylayer_outputs_map_to_new_pylayer_outputs_index.push_back(index++);
+    }
+  } else {
+    if (!output_values.empty()) {
+      PADDLE_THROW(common::errors::Unimplemented(
+          "The last op of the PyLayer block is not yield_op, but a %s",
+          block.back().name()));
+    }
+  }
+
+  if (need_build_new_pylayer) {
+    ::pir::IrContext *ctx = ::pir::IrContext::Instance();
+    block.pop_back();
+    ::pir::Builder builder = ::pir::Builder(ctx, &block);
+    builder.SetInsertionPointToBlockEnd(&block);
+    builder.Build<pir::YieldOp>(new_pylayer_yield_inputs);
+
+    ::pir::Builder builder2 = ::pir::Builder(ctx, program_block);
+    builder2.set_insertion_point(&(**this));
+    auto new_pylayer = builder2.Build<PyLayerOp>(new_pylayer_inputs,
+                                                 forward_region().TakeBack(),
+                                                 backward_function_id());
+    for (size_t i = 0;
+         i < old_pylayer_outputs_map_to_new_pylayer_outputs_index.size();
+         i++) {
+      if (old_pylayer_outputs_map_to_new_pylayer_outputs_index[i] != -1) {
+        output_values[i].ReplaceAllUsesWith(new_pylayer.result(
+            old_pylayer_outputs_map_to_new_pylayer_outputs_index[i]));
+      }
+    }
+    pir::Block::Iterator iter = **this;
+    program_block->erase(iter);
+  }
+}
+
 }  // namespace dialect
 }  // namespace paddle
diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_pylayer_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_pylayer_op.h
index 0deea321d8b18..0cd51b7bb304e 100644
--- a/paddle/fluid/pir/dialect/operator/ir/manual_pylayer_op.h
+++ b/paddle/fluid/pir/dialect/operator/ir/manual_pylayer_op.h
@@ -49,6 +49,14 @@ class PyLayerOp : public pir::Op {
     }
     return input_values;
   }
+
+  std::vector<pir::Value> outputs() const {
+    std::vector<pir::Value> output_values;
+    for (size_t index = 0; index < num_results(); ++index) {
+      output_values.push_back(result(index));
+    }
+    return output_values;
+  }
   pir::Value input(size_t index) const {
     PADDLE_ENFORCE_LT(
         index,
@@ -75,6 +83,7 @@ class PyLayerOp : public pir::Op {
   void VerifyRegion();
 
   void UpdateOutput();
+  void UpdateInputOutput();
 };
 
 }  // namespace dialect
diff --git a/paddle/fluid/pybind/control_flow_api.cc b/paddle/fluid/pybind/control_flow_api.cc
index f883af37937c3..e4f8ae920811d 100644
--- a/paddle/fluid/pybind/control_flow_api.cc
+++ b/paddle/fluid/pybind/control_flow_api.cc
@@ -38,6 +38,8 @@
 #include "paddle/fluid/pybind/python_callable_registry.h"
 
+#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h"
+
 namespace py = pybind11;
 using paddle::dialect::ApiBuilder;
 using paddle::dialect::AssertOp;
@@ -95,6 +97,20 @@ void BindPyLayerOp(py::module* m) {
     return ApiBuilder::Instance().GetBuilder()->Build<PyLayerOp>(
         inputs, std::vector<pir::Type>{}, -1);
   });
+  m->def("updata_pylayer_op", [](pir::Program* program) -> void {
+    std::unordered_set<pir::Value> global_block_inner_inputs;
+    global_block_inner_inputs =
+        paddle::framework::GetInternalInputs(program->block());
+
+    for (auto iter = program->block()->begin();
+         iter != program->block()->end();) {
+      pir::Operation* op_item = &(*iter);
+      ++iter;
+      if (op_item->isa<PyLayerOp>()) {
+        op_item->dyn_cast<PyLayerOp>().UpdateInputOutput();
+      }
+    }
+  });
   py::class_<PyLayerOp> pylayer_op(*m, "PyLayerOp", R"DOC(
     TODO(MarioLulab): Add some docs for pd_op.pylayer
   )DOC");
@@ -103,6 +119,7 @@ void BindPyLayerOp(py::module* m) {
            &PyLayerOp::forward_block,
            return_value_policy::reference)
       .def("update_output", &PyLayerOp::UpdateOutput)
+      .def("update_input_and_output", &PyLayerOp::UpdateInputOutput)
       .def(
           "as_operation", &PyLayerOp::operation, return_value_policy::reference)
       .def("id",
diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py
index 38b330de7b768..9734d3ee704be 100644
--- a/python/paddle/distributed/auto_parallel/api.py
+++ b/python/paddle/distributed/auto_parallel/api.py
@@ -369,58 +369,97 @@ def forward(
         local_tensor_list,
         local_mesh_list,
         local_placements,
-        idx,
-        global_dims,
         mesh,
         placements,
+        global_dims,
+        idx=-1,
     ):
-        local_tensor = local_tensor_list[idx]
-        if local_tensor.is_dist():
-            local_mesh = local_tensor.process_mesh
-            local_val = local_tensor._local_value()
-        else:
-            local_val = local_tensor
-            local_mesh = None
-
-        ctx.global_mesh = copy.deepcopy(mesh)
-        ctx.placements = copy.deepcopy(placements)
-        ctx.local_dims = local_tensor.shape
-        ctx.local_mesh_list = copy.deepcopy(local_mesh_list)
-        ctx.local_placements = copy.deepcopy(local_placements)
+        # NOTE: _local_value and paddle.Tensor() are only supported in dynamic mode
+        if paddle.in_dynamic_mode():
+            local_tensor = local_tensor_list[idx]
+            if local_tensor.is_dist():
+                local_mesh = local_tensor.process_mesh
+                local_val = local_tensor._local_value()
+            else:
+                local_val = local_tensor
+                local_mesh = None
+
+            ctx.save_for_backward(
+                copy.deepcopy(mesh),  # global_mesh
+                local_tensor.shape,  # local_dims
+                copy.deepcopy(local_mesh_list),  # local_mesh_list
+                copy.deepcopy(local_placements),  # local_placements
+            )
 
-        place = paddle.framework._current_expected_place()
-        place = paddle.framework._get_paddle_place(place)
+            place = paddle.framework._current_expected_place()
+            place = paddle.framework._get_paddle_place(place)
 
-        global_tensor = paddle.Tensor(
-            local_val,
-            dims=global_dims,
-            process_mesh=mesh,
-            placements=placements,
-            place=place,
-        )
-        global_tensor.stop_gradient = local_tensor.stop_gradient
-        return global_tensor
+            global_tensor = paddle.Tensor(
+                local_val,
+                dims=global_dims,
+                process_mesh=mesh,
+                placements=placements,
+                place=place,
+            )
+            global_tensor.stop_gradient = local_tensor.stop_gradient
+            return global_tensor
+        else:
+            ctx.save_for_backward(
+                copy.deepcopy(mesh),  # global_mesh
+                copy.deepcopy(placements),  # global_placements
+                copy.deepcopy(local_mesh_list),  # local_mesh_list
+                copy.deepcopy(local_placements),  # local_placements
+            )
+            dist_tensor = paddle._C_ops.moe_global_mesh_tensor(
+                local_tensor_list,
+                local_mesh_list,
+                local_placements,
+                mesh,
+                placements,
+                global_dims,
+            )
+            dist_tensor.stop_gradient = local_tensor_list[0].stop_gradient
+            dist_tensor.persistable = local_tensor_list[0].persistable
+            return dist_tensor
 
     @staticmethod
     def backward(ctx, grad_tensor):
-        if ctx.local_mesh_list is None:
-            return grad_tensor._local_value()
-        else:
-            place = paddle.framework._current_expected_place()
-            place = paddle.framework._get_paddle_place(place)
-            out = []
-            for i, local_mesh in enumerate(ctx.local_mesh_list):
-                out.append(
-                    paddle.Tensor(
-                        grad_tensor._local_value(),
-                        dims=ctx.local_dims,
-                        process_mesh=local_mesh,
placements=ctx.local_placements, - place=place, + if paddle.in_dynamic_mode(): + global_mesh, local_dims, local_mesh_list, local_placements = ( + ctx.saved_tensor() + ) + if local_mesh_list is None: + return grad_tensor._local_value() + else: + place = paddle.framework._current_expected_place() + place = paddle.framework._get_paddle_place(place) + out = [] + for i, local_mesh in enumerate(local_mesh_list): + out.append( + paddle.Tensor( + grad_tensor._local_value(), + dims=local_dims, + process_mesh=local_mesh, + placements=local_placements, + place=place, + ) ) - ) - out[-1].get_tensor()._unsafe_set_skip_check_mesh(True) - return out + out[-1].get_tensor()._unsafe_set_skip_check_mesh(True) + return out + else: + ( + global_mesh, + global_placements, + local_mesh_list, + local_placements, + ) = ctx.saved_tensor() + return paddle._C_ops.moe_sub_mesh_tensors( + grad_tensor, + local_mesh_list, + local_placements, + global_mesh, + global_placements, + ) def _get_sub_meshes_and_local_placements( @@ -469,6 +508,7 @@ def moe_global_mesh_tensor( local_tensor = local_tensor_list[local_tensor_idx] if paddle.in_dynamic_mode(): + # NOTE: _local_value and Paddle.Tensor() is only supported in dynamic mode if local_coord[0].size == 0: local_tensor_shape = _cal_local_shape( local_tensor_list[0].shape, local_mesh_list[0], local_placements @@ -498,16 +538,18 @@ def moe_global_mesh_tensor( resharded_local_tensor_list, local_mesh_list, local_placements, - local_tensor_idx, - global_dims, mesh, placements, + global_dims, + local_tensor_idx, ) elif paddle.framework.in_pir_mode(): global_dims = _cal_global_shape( local_tensor._local_shape, mesh, placements ) - dist_tensor = paddle._C_ops.moe_global_mesh_tensor( + return paddle.jit.dy2static.py_layer.StaticPyLayer( + _moe_global_mesh_tensor + ).apply( local_tensor_list, local_mesh_list, local_placements, @@ -515,10 +557,6 @@ def moe_global_mesh_tensor( placements, global_dims, ) - dist_tensor.stop_gradient = local_tensor_list[0].stop_gradient - dist_tensor.persistable = local_tensor_list[0].persistable - - return dist_tensor else: raise NotImplementedError( "dtensor_from_local_list() are only supported in dynamic and pir mode." @@ -536,75 +574,115 @@ def forward( global_mesh=None, global_placements=None, ): - ctx.local_mesh_list = copy.deepcopy(local_mesh_list) - ctx.local_placements = local_placements - ctx.local_mesh_dim = local_mesh_dim - ctx.global_mesh = copy.deepcopy(global_mesh) - ctx.global_placements = global_placements - ctx.global_shape = dist_tensor.shape - - if global_mesh is None and global_placements is None: - return dist_tensor._local_value() - else: - if global_mesh is None or global_placements is None: - raise ValueError( - "the args global_mesh and global_placements should be set together" - ) - ori_mesh = dist_tensor.process_mesh - if global_mesh != dist_tensor.process_mesh: - raise ValueError( - "the global_mesh should be the same as dist_tensor's process_mesh." - ) - assert check_placements_equal( - global_placements, dist_tensor.placements - ), f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})." 
- local_shape = _cal_local_shape( - dist_tensor.shape, global_mesh, global_placements - ) - for idx, placement in enumerate(local_placements): - if placement.is_shard(): - shard_dim = placement.get_dim() - local_dim_size = local_shape[shard_dim] - local_shape[shard_dim] = ( - local_dim_size * local_mesh_list[0].shape[idx] - ) - place = paddle.framework._current_expected_place() - place = paddle.framework._get_paddle_place(place) - local_tensor_list = [] - for i, local_mesh in enumerate(local_mesh_list): - local_tensor = paddle.Tensor( - dist_tensor._local_value(), - dims=local_shape, - process_mesh=local_mesh, - placements=local_placements, - place=place, + ctx.save_for_backward( + copy.deepcopy(local_mesh_list), # local_mesh_list, + local_placements, # local_placements, + local_mesh_dim, # local_mesh_dim, + copy.deepcopy(global_mesh), # global_mesh, + global_placements, # global_placements, + dist_tensor.shape, # global_shape, + ) + if paddle.in_dynamic_mode(): + if global_mesh is None and global_placements is None: + return dist_tensor._local_value() + else: + if global_mesh is None or global_placements is None: + raise ValueError( + "the args global_mesh and global_placements should be set together" + ) + ori_mesh = dist_tensor.process_mesh + if global_mesh != dist_tensor.process_mesh: + raise ValueError( + "the global_mesh should be the same as dist_tensor's process_mesh." + ) + assert check_placements_equal( + global_placements, dist_tensor.placements + ), f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})." + local_shape = _cal_local_shape( + dist_tensor.shape, global_mesh, global_placements ) - local_tensor.get_tensor()._unsafe_set_skip_check_mesh(True) + for idx, placement in enumerate(local_placements): + if placement.is_shard(): + shard_dim = placement.get_dim() + local_dim_size = local_shape[shard_dim] + local_shape[shard_dim] = ( + local_dim_size * local_mesh_list[0].shape[idx] + ) + + place = paddle.framework._current_expected_place() + place = paddle.framework._get_paddle_place(place) + local_tensor_list = [] + for i, local_mesh in enumerate(local_mesh_list): + local_tensor = paddle.Tensor( + dist_tensor._local_value(), + dims=local_shape, + process_mesh=local_mesh, + placements=local_placements, + place=place, + ) + local_tensor.get_tensor()._unsafe_set_skip_check_mesh(True) + local_tensor.stop_gradient = dist_tensor.stop_gradient + local_tensor_list.append(local_tensor) + return local_tensor_list + elif paddle.framework.in_pir_mode(): + local_tensors = paddle._C_ops.moe_sub_mesh_tensors( + dist_tensor, + local_mesh_list, + local_placements, + global_mesh, + global_placements, + ) + for local_tensor in local_tensors: local_tensor.stop_gradient = dist_tensor.stop_gradient - local_tensor_list.append(local_tensor) - return local_tensor_list + local_tensor.persistable = dist_tensor.persistable + return local_tensors @staticmethod def backward(ctx, *grad_tensor): + ( + local_mesh_list, + local_placements, + local_mesh_dim, + global_mesh, + global_placements, + global_shape, + ) = ctx.saved_tensor() place = paddle.framework._current_expected_place() place = paddle.framework._get_paddle_place(place) - mesh = ctx.global_mesh + mesh = global_mesh process_ids = np.array(mesh.process_ids).reshape(mesh.shape) local_coord = np.where(process_ids == dist.get_rank()) if local_coord[0].size == 0: local_tensor_idx = 0 else: - local_tensor_idx = local_coord[ctx.local_mesh_dim][0] + local_tensor_idx = local_coord[local_mesh_dim][0] 
local_grad = grad_tensor[local_tensor_idx] - global_tensor = paddle.Tensor( - local_grad._local_value(), - dims=ctx.global_shape, - process_mesh=mesh, - placements=ctx.global_placements, - place=place, - ) - return global_tensor + + if paddle.in_dynamic_mode(): + place = paddle.framework._current_expected_place() + place = paddle.framework._get_paddle_place(place) + global_tensor = paddle.Tensor( + local_grad._local_value(), + dims=global_shape, + process_mesh=mesh, + placements=global_placements, + place=place, + ) + return global_tensor + elif paddle.framework.in_pir_mode(): + global_dims = _cal_global_shape( + local_grad._local_shape, mesh, global_placements + ) + + return paddle._C_ops.moe_global_mesh_tensor( + grad_tensor, + local_mesh_list, + local_placements, + global_mesh, + global_placements, + global_dims, + ) def moe_sub_mesh_tensors( @@ -627,17 +705,17 @@ def moe_sub_mesh_tensors( global_placements, ) elif paddle.framework.in_pir_mode(): - local_tensors = paddle._C_ops.moe_sub_mesh_tensors( + + return paddle.jit.dy2static.py_layer.StaticPyLayer( + _moe_sub_mesh_tensors + ).apply( dist_tensor, local_mesh_list, local_placements, + local_mesh_dim, global_mesh, global_placements, ) - for local_tensor in local_tensors: - local_tensor.stop_gradient = dist_tensor.stop_gradient - local_tensor.persistable = dist_tensor.persistable - return local_tensors else: raise NotImplementedError( "moe_sub_mesh_tensors is only supported in dynamic mode." diff --git a/python/paddle/distributed/auto_parallel/static/pir_pass.py b/python/paddle/distributed/auto_parallel/static/pir_pass.py index 080a9015c4d79..86b90eb211a04 100644 --- a/python/paddle/distributed/auto_parallel/static/pir_pass.py +++ b/python/paddle/distributed/auto_parallel/static/pir_pass.py @@ -405,46 +405,79 @@ def replace_moe_sub_mesh_tensors(op): ) ) + # update pylayer op by removing the unused outputs + def update_pylayer_output(trival_value): + define_op = trival_value.get_defining_op() + if define_op.get_parent_block().parent_op.name() != "pd_op.pylayer": + return + paddle.pir.set_insertion_point(define_op) + fake_value = paddle.static.data( + name="_fake_pylayer_out", + shape=trival_value.shape, + dtype=trival_value.dtype, + ) + fake_value.set_type(trival_value.type()) + trival_value.replace_all_uses_with(fake_value) + + for val in op.results(): + if not val.use_empty(): + update_pylayer_output(val) + assert all(val.use_empty() for val in op.results()) op.erase() +def remove_sub_block_unused_inputs(op): + inputs_size = op.operand_source.num_operands() + inputs = [op.operand_source(i) for i in range(inputs_size)] + # remove unused inputs + + class RemovePasses: + @staticmethod def remove_other_rank_op_pass(dist_program): # pruning op and value not belong to cur rank - cur_rank = paddle.distributed.get_rank() + def prune_op(block): + cur_rank = paddle.distributed.get_rank() + for op in block.ops[::-1]: + if op.name() == "dist_op.moe_sub_mesh_tensors": + replace_moe_sub_mesh_tensors(op) + continue + elif op.name() == "dist_op.moe_global_mesh_tensor": + replace_moe_global_mesh_tensor(op) + continue + elif op.name() == "cf.tuple_push": + stack_create_op = op.operand_source(0).get_defining_op() + if stack_create_op.result(2).use_empty(): + op.erase() + continue + elif op.name() == "cf.yield": + continue + elif op.name() == "pd_op.pylayer": + for pylayer_block in list(op.blocks())[::-1]: + prune_op(pylayer_block) + # update pylayer op's inputs + op.as_pylayer_op().update_input_and_output() + continue + elif op.name() in 
partition_skip_op_list: + can_delete = True + for val in op.results(): + if not val.use_empty(): + can_delete = False + if can_delete: + op.erase() + continue - for op in dist_program.global_block().ops[::-1]: - if op.name() == "dist_op.moe_sub_mesh_tensors": - replace_moe_sub_mesh_tensors(op) - continue - elif op.name() == "dist_op.moe_global_mesh_tensor": - replace_moe_global_mesh_tensor(op) - continue - elif op.name() == "cf.tuple_push": - stack_create_op = op.operand_source(0).get_defining_op() - if stack_create_op.result(2).use_empty(): + if cur_rank not in op.dist_attr.process_mesh.process_ids: op.erase() - continue - elif op.name() == "cf.yield": - continue - elif op.name() in partition_skip_op_list: - can_delete = True - for val in op.results(): - if not val.use_empty(): - can_delete = False - if can_delete: + elif op.name() == "dist_op.reshard": + assert op.result( + 0 + ).use_empty(), f'There should not have useful dist.reshard op in remove_other_rank_op_pass. but find : {op}' op.erase() - continue - if cur_rank not in op.dist_attr.process_mesh.process_ids: - op.erase() - elif op.name() == "dist_op.reshard": - assert op.result( - 0 - ).use_empty(), f'There should not have useful dist.reshard op in remove_other_rank_op_pass. but find : {op}' - op.erase() + prune_op(dist_program.global_block()) # merge pd.data ops for lr_ops = [] diff --git a/test/auto_parallel/pir/test_moe_api.py b/test/auto_parallel/pir/test_moe_api.py index ceeb2d3c8104d..59b8cc19e7629 100644 --- a/test/auto_parallel/pir/test_moe_api.py +++ b/test/auto_parallel/pir/test_moe_api.py @@ -128,16 +128,13 @@ def check_results( local_dims_mapping, ): # local_tensors_from_dtensor op - self.check_dist_attr(ops[2], local_meshes, local_dims_mapping) - + self.check_dist_attr(ops[4], local_meshes, local_dims_mapping) # dtensor_from_local_list op - self.check_dist_attr(ops[3], [global_mesh], global_dims_mapping) - + self.check_dist_attr(ops[5], [global_mesh], global_dims_mapping) # grad op for dtensor_from_local_list - self.check_dist_attr(ops[8], local_meshes, local_dims_mapping) - + self.check_dist_attr(ops[10], local_meshes, local_dims_mapping) # grad op for local_tensors_from_dtensor op - self.check_dist_attr(ops[9], [global_mesh], global_dims_mapping) + self.check_dist_attr(ops[11], [global_mesh], global_dims_mapping) if __name__ == "__main__":
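For reviewers who want to try the new static-graph path locally, the sketch below shows the dispatch pattern this PR applies to `moe_global_mesh_tensor` / `moe_sub_mesh_tensors`: a single `PyLayer` class shared by dygraph and PIR mode, with the PIR branch wrapped in `StaticPyLayer` so it lowers to a `pd_op.pylayer` op that `remove_other_rank_op_pass` can later prune and shrink via `update_input_and_output()`. This is a minimal illustration, not part of the PR; the class name `_Scale`, the helper `scaled`, and the scale factor are made up for the example.

```python
# Minimal sketch of the dynamic/PIR dispatch pattern used in this PR.
import paddle
from paddle.autograd import PyLayer


class _Scale(PyLayer):
    @staticmethod
    def forward(ctx, x, factor):
        # Non-tensor state is carried through save_for_backward,
        # mirroring how the MoE PyLayers now stash meshes/placements.
        ctx.save_for_backward(factor)
        return x * factor

    @staticmethod
    def backward(ctx, grad):
        (factor,) = ctx.saved_tensor()
        return grad * factor


def scaled(x, factor=2.0):
    if paddle.in_dynamic_mode():
        # Dygraph: plain PyLayer.apply.
        return _Scale.apply(x, factor)
    elif paddle.framework.in_pir_mode():
        # PIR: the same class is wrapped in StaticPyLayer, which lowers to a
        # pd_op.pylayer op; update_input_and_output() can later strip any
        # inputs/outputs that become unused after pruning.
        return paddle.jit.dy2static.py_layer.StaticPyLayer(_Scale).apply(x, factor)
    raise NotImplementedError("scaled() is only sketched for dynamic and PIR mode.")
```

In dygraph the call reduces to `_Scale.apply`; under dy2static the same Python code produces the `pd_op.pylayer` op that the updated pass in `pir_pass.py` walks into and rewrites.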