Skip to content

Commit

Permalink
[CINN] add compile error handler (PaddlePaddle#57198)
Browse files Browse the repository at this point in the history
* [CINN] add compile error handler

* [Fix] cinn_compiler return value

* Refine code style
  • Loading branch information
BiynXu authored Sep 18, 2023
1 parent cdd0461 commit ea35c7f
Show file tree
Hide file tree
Showing 17 changed files with 730 additions and 78 deletions.
2 changes: 1 addition & 1 deletion paddle/cinn/auto_schedule/measure/simple_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ BuildResult SimpleBuilder::Build(const MeasureInput& input) {

BuildResult build_result;
build_result.compiled_scope = graph_compiler_->GetScope().get();
build_result.runtime_program = std::move(compiled_result.runtime_program);
build_result.runtime_program = std::move(compiled_result.RuntimeProgram());
return build_result;
}

Expand Down
49 changes: 32 additions & 17 deletions paddle/cinn/backends/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/visualize_helper.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#ifdef CINN_WITH_CUDA
Expand All @@ -39,6 +40,7 @@ PD_DECLARE_string(cinn_dump_group_instruction);
namespace cinn {
namespace backends {
using ir::Module;
using CompilationStatus = hlir::framework::CompilationStatus;

static constexpr int DebugLogMaxLen = 30000;

Expand Down Expand Up @@ -88,9 +90,13 @@ void CompilationInfoDumper::DumpLoweredFunc() {
if (FLAGS_cinn_dump_group_lowered_func.empty()) {
return;
}
for (int idx = 0; idx < info_.lowered_funcs.size(); ++idx) {
for (int idx = 0; idx < info_.Size(); ++idx) {
std::stringstream content;
content << info_.lowered_funcs[idx].front();
if (info_.Status(idx) > CompilationStatus::LOWERING_FAIL) {
content << info_.LoweredFuncs(idx).front();
} else {
content << "[No lowered func generated]\n\n" << info_.Message(idx);
}
Dump(FLAGS_cinn_dump_group_lowered_func,
idx,
"lowered_function.txt",
Expand All @@ -102,35 +108,44 @@ void CompilationInfoDumper::DumpSourceCode() {
if (FLAGS_cinn_dump_group_source_code.empty()) {
return;
}
for (int idx = 0; idx < info_.source_codes.size(); ++idx) {
Dump(FLAGS_cinn_dump_group_source_code,
idx,
"source_code.cu",
info_.source_codes[idx]);
for (int idx = 0; idx < info_.Size(); ++idx) {
std::string dump_str;
if (info_.Status(idx) > CompilationStatus::CODEGEN_JIT_FAIL) {
dump_str = info_.SourceCode(idx);
} else {
dump_str = "[No source code generated]\n\n" + info_.Message(idx);
}
Dump(FLAGS_cinn_dump_group_source_code, idx, "source_code.cu", dump_str);
}
}

void CompilationInfoDumper::DumpPtxCode() {
if (FLAGS_cinn_dump_group_ptx.empty()) {
return;
}
for (int idx = 0; idx < info_.source_ptxs.size(); ++idx) {
Dump(FLAGS_cinn_dump_group_ptx,
idx,
"source_ptx.ptx",
info_.source_ptxs[idx]);
for (int idx = 0; idx < info_.Size(); ++idx) {
std::string dump_str;
if (info_.Status(idx) > CompilationStatus::CODEGEN_JIT_FAIL) {
dump_str = info_.SourcePtx(idx);
} else {
dump_str = "[No source ptxs generated]\n\n" + info_.Message(idx);
}
Dump(FLAGS_cinn_dump_group_ptx, idx, "source_ptx.ptx", dump_str);
}
}

void CompilationInfoDumper::DumpInstruction() {
if (FLAGS_cinn_dump_group_instruction.empty()) {
return;
}
for (int idx = 0; idx < info_.instructions.size(); ++idx) {
Dump(FLAGS_cinn_dump_group_instruction,
idx,
"instruction.txt",
info_.instructions[idx]->DumpInstruction());
for (int idx = 0; idx < info_.RuntimeInstructions().size(); ++idx) {
std::string dump_str;
if (info_.RuntimeInstruction(idx).get() != nullptr) {
dump_str = info_.RuntimeInstruction(idx)->DumpInstruction();
} else {
dump_str = "[No instruction generated]\n\n" + info_.Message(idx);
}
Dump(FLAGS_cinn_dump_group_instruction, idx, "instruction.txt", dump_str);
}
}

Expand Down
3 changes: 2 additions & 1 deletion paddle/cinn/hlir/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ gather_srcs(
op_lowering_util.cc
op_lowering_impl.cc
accuracy_checker.cc
visualize_helper.cc)
visualize_helper.cc
compile_error.cc)

# TODO(Aurelius84): new_ir_compiler depends on pd_op_dialect and could
# not found under CINN_ONLY mode
Expand Down
41 changes: 41 additions & 0 deletions paddle/cinn/hlir/framework/compile_error.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/framework/compile_error.h"
#include "paddle/cinn/utils/enum_string.h"

namespace cinn {
namespace hlir {
namespace framework {

std::string CompileErrorHandler::GeneralErrorMessage() const {
std::ostringstream os;
os << "[CompileError] An error occurred during compilation with the error "
"code: "
<< utils::Enum2String(status_) << std::endl;
os << "(at " << file_ << " : " << line_ << ")" << std::endl;
os << indent_str_ << "[Error info] " << this->err_msg_ << std::endl;
return os.str();
}

std::string CompileErrorHandler::DetailedErrorMessage() const {
std::ostringstream os;
os << GeneralErrorMessage();
os << indent_str_ << "[Detail info] " << detail_info_ << std::endl;
return os.str();
}

} // namespace framework
} // namespace hlir
} // namespace cinn
68 changes: 68 additions & 0 deletions paddle/cinn/hlir/framework/compile_error.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/utils/error.h"

namespace cinn {
namespace hlir {
namespace framework {

/**
* This handler is used to deal with the errors during the compilation process
*/
class CompileErrorHandler : public utils::ErrorHandler {
public:
/**
* \brief constructor
* \param err_msg the error message
*/
explicit CompileErrorHandler(const CompilationStatus& status,
const std::string& err_msg,
const std::string& detail_info,
const char* file,
int line)
: status_(status),
err_msg_(err_msg),
detail_info_(detail_info),
file_(file),
line_(line) {}

/**
* \brief Returns a short error message corresponding to the kGeneral error
* level.
*/
std::string GeneralErrorMessage() const;

/**
* \brief Returns a detailed error message corresponding to the kDetailed
* error level.
*/
std::string DetailedErrorMessage() const;

CompilationStatus Status() const { return status_; }

private:
CompilationStatus status_;
std::string err_msg_;
std::string detail_info_;
const char* file_;
int line_;
};

} // namespace framework
} // namespace hlir
} // namespace cinn
13 changes: 7 additions & 6 deletions paddle/cinn/hlir/framework/graph_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
#include "paddle/cinn/poly/stage.h"
#include "paddle/cinn/utils/enum_string.h"
#include "paddle/cinn/utils/profiler.h"

namespace cinn {
Expand All @@ -44,7 +45,7 @@ std::unique_ptr<Program> GraphCompiler::Build(const std::string& code) {
compilation_context_.with_instantiate_variables = true;

auto&& result = Build(&compilation_context_);
return std::move(result.runtime_program);
return result.RuntimeProgram();
}

CompilationResult GraphCompiler::Build(CompilationContext* context) {
Expand All @@ -64,22 +65,22 @@ CompilationResult GraphCompiler::Build(CompilationContext* context) {
parallel_compiler_ = std::make_shared<ParallelCompiler>(context);
CompilationResult result = (*parallel_compiler_.get())();

if (context->stage != CompilationStage::DEFAULT) {
if (context->stage != CompilationStage::DEFAULT || !result.IsSuccess()) {
return result;
}

if (context->remove_unused_variables) {
RemoveInvalidVariables(context, result.instructions);
RemoveInvalidVariables(context, result.RuntimeInstructions());
}

if (context->with_buffer_handle_instruction_inserted) {
VLOG(3) << "option.with_buffer_handle_instruction_inserted enable";
InsertBufferHandlers(context, &result.instructions);
InsertBufferHandlers(context, &result.instructions_);
}
VLOG(2) << "Compile With Parallel Compiler Done!";

result.runtime_program =
std::make_unique<Program>(context->scope, std::move(result.instructions));
result.SetRuntimeProgram(std::make_unique<Program>(
context->scope, std::move(result.instructions_)));
return result;
}

Expand Down
10 changes: 5 additions & 5 deletions paddle/cinn/hlir/framework/graph_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ TEST(GraphCompilerTest, TestInsertBufferHandlers) {
GraphCompiler gc_disable(context_disable);
// disable with_buffer_handle_instruction_inserted: only 1 instruction
auto runtime_program_disable =
gc_disable.Build(&context_disable).runtime_program;
gc_disable.Build(&context_disable).RuntimeProgram();
ASSERT_EQ(runtime_program_disable->size(), 1);
const auto& computation_instr_disable =
runtime_program_disable->GetRunInstructions().front();
Expand All @@ -87,7 +87,7 @@ TEST(GraphCompilerTest, TestInsertBufferHandlers) {
context_enable.with_buffer_handle_instruction_inserted = true;
GraphCompiler gc_enable(context_enable);
auto runtime_program_enable =
gc_enable.Build(&context_enable).runtime_program;
gc_enable.Build(&context_enable).RuntimeProgram();
const auto& instructions = runtime_program_enable->GetRunInstructions();
ASSERT_EQ(instructions.size(), 3);

Expand Down Expand Up @@ -254,7 +254,7 @@ TEST(GraphCompilerTest, TestLowering) {
GraphCompiler gc(context);
CompilationResult result = gc.Lowering();

ASSERT_EQ(result.status, CompilationStatus::SUCCESS);
ASSERT_EQ(result.Status(), CompilationStatus::SUCCESS);
}

TEST(GraphCompilerTest, TestCodegenAndJit) {
Expand All @@ -274,7 +274,7 @@ TEST(GraphCompilerTest, TestCodegenAndJit) {
GraphCompiler gc(context);
CompilationResult result = gc.CodegenAndJit();

ASSERT_EQ(result.status, CompilationStatus::SUCCESS);
ASSERT_EQ(result.Status(), CompilationStatus::SUCCESS);
}

TEST(GraphCompilerTest, TestBuildInstruction) {
Expand All @@ -294,7 +294,7 @@ TEST(GraphCompilerTest, TestBuildInstruction) {
GraphCompiler gc(context);
CompilationResult result = gc.BuildInstruction();

ASSERT_EQ(result.status, CompilationStatus::SUCCESS);
ASSERT_EQ(result.Status(), CompilationStatus::SUCCESS);
}

#endif
Expand Down
Loading

0 comments on commit ea35c7f

Please sign in to comment.