Skip to content

Commit

Permalink
[CINN&AutoParallel] Fix some bugs for CINN with AutoParallel (#68731)
Browse files Browse the repository at this point in the history
* fix

* fix
  • Loading branch information
zhangbo9674 authored Oct 16, 2024
1 parent 97c90c4 commit d270cab
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@ bool RemoveOp(pir::Operation* op,
pir::PatternRewriter* rewriter,
bool check_dtype = false) {
const auto& IsDynamicShape = [](const pir::Value& value) -> bool {
return value.type().dyn_cast<pir::ShapedTypeInterface>().IsDynamicShape();
auto shape_type = value.type().dyn_cast<pir::ShapedTypeInterface>();
if (shape_type && shape_type.IsDynamicShape()) {
return true;
}
return false;
};
const auto& GetDims = [](const pir::Value& value) -> decltype(auto) {
return value.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,8 @@ pir::Type CastToLocalType(pir::Type type) {
local_types.push_back(CastToLocalType(vec_type[i]));
}
return pir::VectorType::get(vec_type.ir_context(), local_types);
} else if (!type || type.isa<pir::StackType>() ||
type.isa<pir::InletType>() || type.isa<pir::OutletType>()) {
// skip if <<NULL TYPE>>
return type;
} else {
// TODO(2024-Q2) not all value are dist type
PADDLE_THROW(common::errors::PreconditionNotMet(
"The type[%s] is not Dist type.", type));
return type;
}
}

Expand Down
5 changes: 4 additions & 1 deletion paddle/pir/src/pass/pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/pir/include/pass/pass.h"
#include <glog/logging.h>

#include "paddle/pir/include/core/ir_context.h"
#include "paddle/pir/include/core/operation.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/core/region.h"
#include "paddle/pir/include/core/verify.h"
#include "paddle/pir/include/pass/pass.h"
#include "paddle/pir/include/pass/pass_instrumentation.h"
#include "paddle/pir/include/pass/pass_manager.h"
#include "paddle/pir/include/pattern_rewrite/pattern_match.h"
Expand Down Expand Up @@ -86,6 +88,7 @@ GreedyRewriteConfig PatternRewritePass::InitializeConfig() {
}

void PatternRewritePass::Run(Operation* op) {
VLOG(4) << "Run PatternRewritePass: " << name();
auto [_, num_rewrites] =
ApplyPatternsGreedily(op, patterns_, InitializeConfig());
AddStatistics(num_rewrites);
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/distributed/auto_parallel/static/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1432,6 +1432,9 @@ def _initialize(self, mode, init_parameters=True):

set_all_ops_op_role(startup_prog.global_block(), OpRole.Forward)
ReshardPasses.apply_reshard_pass(startup_prog)
paddle.base.libpaddle.pir.apply_dist2dense_pass(startup_prog)
remove_unuseful_comm_op_pass(startup_prog)

for op in changed_ouput_op_list:
op.operand_source(0).persistable = True
self._executor.run(startup_prog)
Expand Down

0 comments on commit d270cab

Please sign in to comment.