Skip to content

Commit

Permalink
[AOT] Remove lookup parameter function in AOT (apache#7988)
Browse files Browse the repository at this point in the history
* AOT] Remove lookup parameter function in AOT

This PR aims at removing the function call to extract the parameters
within the AOT main function by introducing a tir::lookup_param builtin.

This has different benefits:
- In AOT we now only use the v_handle field
- We save cycles by not calling an intermediate function to extract
local parameters
- We reduce code size, since we don't need to pack a call to extract
parameters and we don't need to produce the lookup_param function
anymore within the compilation unit

Change-Id: I36c2f0724a79606424a4374f4f5cd669bb2a8a55

* addressing comments

Change-Id: I83ba0189f559d310b5a80fe0bcc4d601b490d21a

* retrigger CI

Change-Id: I84ab4a526d1284ded41fe95636e94c15412f6b28
  • Loading branch information
Giuseppe Rossini authored May 20, 2021
1 parent ec3b160 commit 71ff875
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 37 deletions.
8 changes: 8 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,14 @@ TVM_DLL const Op& tvm_struct_get();
*/
TVM_DLL const Op& tvm_struct_set();

/*!
* \brief See pseudo code
* Type lookup_param(String param_name) {
* return __tvm_param__param_name;
* }
*/
TVM_DLL const Op& lookup_param();

/*!
* \brief See pesudo code
*
Expand Down
25 changes: 4 additions & 21 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,40 +152,23 @@ class AOTExecutorCodegen : public ExprVisitor {
* \return Variable that represents the DLTensor associated with the parameters
*/
tir::Var PackParam(Expr expr) {
// TODO(giuseros): Using call_extern to call into lookup_linked_param. This is because the
// builtin::ret is not supported yet in the c target. Once return is supported we can use
// tvm_call_packed_lowered().
int param_sid = param_storage_ids_[params_by_expr_[expr]];
auto lookup_linked_param_fn = tir::StringImm(::tvm::runtime::symbol::tvm_lookup_linked_param);
auto param_array = te::Var(MakeString("param_", param_sid, "_array"), DataType::Handle());

// Compose the lookup_call using a local stack
Array<tir::Stmt> lookup_call;
auto param_var = te::Var(MakeString("param_", param_sid, "_value"), DataType::Handle());
auto ret_var = te::Var("ret_value", DataType::Handle());
auto ret_code = te::Var("ret_value", DataType::Handle());

lookup_call.push_back(tir::Evaluate(
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
{param_var, 0, tir::builtin::kTVMValueContent, ConstInt32(param_sid)})));
lookup_call.push_back(tir::Evaluate(
tvm::tir::Call(DataType::Handle(), tir::builtin::call_extern(),
{lookup_linked_param_fn, param_var, 0, 0, ret_var, ret_code, 0})));
auto ret_var_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(),
{ret_var, 0, tir::builtin::kTVMValueContent});

// Set the param to the value returned by lookup_call
auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
{tir::StringImm(params_by_expr_[expr])});

tvm::PrimExpr set_param_array =
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
{param_array, 0, tir::builtin::kArrData, ret_var_handle});
{param_array, 0, tir::builtin::kArrData, param_handle});
lookup_call.push_back(tir::Evaluate(set_param_array));

tir::Stmt lookup_body = tir::SeqStmt(lookup_call);

// Allocate the DLTensors on the stack
lookup_body = tir::LetStmt(param_var, StackAlloca("arg_value", 1), lookup_body);
lookup_body = tir::LetStmt(ret_var, StackAlloca("arg_value", 1), lookup_body);
lookup_body = tir::LetStmt(ret_code, StackAlloca("arg_value", 1), lookup_body);
lookup_body = tir::LetStmt(param_array, StackAlloca("arg_value", 1), lookup_body);
stmts_.push_back(lookup_body);
return param_array;
Expand Down
5 changes: 5 additions & 0 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,11 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
os << " != ";
this->PrintExpr(op->args[0], os);
os << ")";
} else if (op->op.same_as(builtin::lookup_param())) {
ICHECK_EQ(op->args.size(), 1);
const StringImmNode* str = op->args[0].as<StringImmNode>();
ICHECK(str != nullptr);
os << "__tvm_param__" << str->value;
} else {
LOG(FATAL) << "Unresolved call " << op->op;
}
Expand Down
39 changes: 23 additions & 16 deletions src/target/source/codegen_c_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,7 @@ void CodeGenCHost::AddFunction(const PrimFunc& f) {
CodeGenC::AddFunction(f);
}

void CodeGenCHost::LinkParameters(Map<String, LinkedParam> params) {
PrintFuncPrefix();
stream << " " << tvm::runtime::symbol::tvm_lookup_linked_param
<< "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, "
<< "int* out_ret_tcode, void* resource_handle) {\n";
ICHECK_EQ(GetUniqueName(tvm::runtime::symbol::tvm_lookup_linked_param),
tvm::runtime::symbol::tvm_lookup_linked_param)
<< "builtin PackedFunc name already taken: " << tvm::runtime::symbol::tvm_lookup_linked_param;
stream << " switch (((int64_t*) args)[0]) {\n"
<< " default:\n"
<< " out_ret_tcode[0] = " << kTVMNullptr << ";\n"
<< " return 0;\n";

function_names_.push_back(tvm::runtime::symbol::tvm_lookup_linked_param);
void CodeGenCHost::DeclareParameters(Map<String, LinkedParam> params) {
for (auto kv : params) {
decl_stream << "\n"
<< "#ifdef __cplusplus\n"
Expand All @@ -93,6 +80,24 @@ void CodeGenCHost::LinkParameters(Map<String, LinkedParam> params) {
<< "#ifdef __cplusplus\n"
<< "} // extern \"C\"\n"
<< "#endif\n";
}
}

void CodeGenCHost::LinkParameters(Map<String, LinkedParam> params) {
PrintFuncPrefix();
stream << " " << tvm::runtime::symbol::tvm_lookup_linked_param
<< "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, "
<< "int* out_ret_tcode, void* resource_handle) {\n";
ICHECK_EQ(GetUniqueName(tvm::runtime::symbol::tvm_lookup_linked_param),
tvm::runtime::symbol::tvm_lookup_linked_param)
<< "builtin PackedFunc name already taken: " << tvm::runtime::symbol::tvm_lookup_linked_param;
stream << " switch (((int64_t*) args)[0]) {\n"
<< " default:\n"
<< " out_ret_tcode[0] = " << kTVMNullptr << ";\n"
<< " return 0;\n";

function_names_.push_back(tvm::runtime::symbol::tvm_lookup_linked_param);
for (auto kv : params) {
stream << " case " << kv.second->id << ":\n"
<< " ((uint64_t*)out_ret_value)[0] = (uint64_t) (uintptr_t) "
<< ::tvm::runtime::symbol::tvm_param_prefix << kv.first << ";\n"
Expand Down Expand Up @@ -398,12 +403,14 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
cg.AddFunction(f);
}

if (could_have_linked_params) {
if (could_have_linked_params && !aot_executor_fn.defined()) {
ICHECK(found_linked_params) << "-link-params given but none found";
cg.DeclareParameters(linked_params);
cg.LinkParameters(linked_params);
}

if (aot_executor_fn.defined()) {
if (could_have_linked_params && aot_executor_fn.defined()) {
cg.DeclareParameters(linked_params);
cg.AddFunction(aot_executor_fn);
}

Expand Down
1 change: 1 addition & 0 deletions src/target/source/codegen_c_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class CodeGenCHost final : public CodeGenC {
void AddFunction(const PrimFunc& f);

/*! \brief Add linked parameters, if they are present. */
void DeclareParameters(Map<String, LinkedParam> params);
void LinkParameters(Map<String, LinkedParam> params);

void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
Expand Down
4 changes: 4 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_struct_set)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kUpdateState));

TIR_DEFINE_BUILTIN_FUNC(lookup_param)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kUpdateState));

TIR_DEFINE_BUILTIN_FUNC(tvm_throw_last_error)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
Expand Down

0 comments on commit 71ff875

Please sign in to comment.