Skip to content

Commit

Permalink
[CodeGenC] Fix bugs when calling extern functions (apache#7911)
Browse files Browse the repository at this point in the history
  • Loading branch information
leeexyz authored Apr 26, 2021
1 parent 54fdcc5 commit 82fecbf
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 17 deletions.
4 changes: 4 additions & 0 deletions python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,10 @@ def load_module(path, fmt=""):
This function will automatically call
cc.create_shared if the path is in format .o or .tar
"""
if os.path.isfile(path):
path = os.path.realpath(path)
else:
raise ValueError("cannot find file %s" % path)

# c++ compiler/linker
cc = os.environ.get("CXX", "g++")
Expand Down
14 changes: 11 additions & 3 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -960,11 +960,19 @@ void CodeGenC::VisitStmt_(const EvaluateNode* op) {
return;
} else if (call->op.same_as(builtin::tvm_struct_set())) {
ICHECK_EQ(call->args.size(), 4);
int kind = call->args[2].as<IntImmNode>()->value;
std::string ref = GetStructRef(call->args[3].dtype(), call->args[0], call->args[1], kind);
std::string value = PrintExpr(call->args[3]);
std::string ref = GetStructRef(call->args[3].dtype(), call->args[0], call->args[1],
call->args[2].as<IntImmNode>()->value);
std::string cast;
if (kind == builtin::kArrStrides) {
// cast void* to int64_t*
cast = call->args[3]->dtype.is_handle() ? "(int64_t*)" : "";
} else if (kind == builtin::kArrDeviceType) {
// cast int to enum
cast = "(DLDeviceType)";
}
this->PrintIndent();
this->stream << ref << " = " << value << ";\n";
this->stream << ref << " = " << cast << value << ";\n";
return;
}
}
Expand Down
17 changes: 10 additions & 7 deletions src/target/source/codegen_c_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,17 @@ void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT
// NOTE: cannot rely on GetUnique for global decl_stream declarations
// because it is reset between AddFunction().
std::string packed_func_name = func_name + "_packed";
if (declared_globals_.insert(packed_func_name).second) {
// Still reserve the name among unique names.
ICHECK(GetUniqueName(packed_func_name) == packed_func_name)
<< "Expected name " << packed_func_name << " to not be taken";
decl_stream << "static void* " << packed_func_name << " = NULL;\n";
std::string unique_name;
auto it = declared_globals_.find(packed_func_name);
if (it != declared_globals_.end()) {
unique_name = it->second;
} else {
unique_name = GetUniqueName(packed_func_name);
declared_globals_[packed_func_name] = unique_name;
decl_stream << "static void* " << unique_name << " = NULL;\n";
}
this->PrintGetFuncFromBackend(func_name, packed_func_name);
this->PrintFuncCall(packed_func_name, num_args);
this->PrintGetFuncFromBackend(func_name, unique_name);
this->PrintFuncCall(unique_name, num_args);
} else if (op->op.same_as(builtin::tvm_throw_last_error())) {
this->PrintIndent();
this->stream << "return -1;\n";
Expand Down
6 changes: 3 additions & 3 deletions src/target/source/codegen_c_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
#ifndef TVM_TARGET_SOURCE_CODEGEN_C_HOST_H_
#define TVM_TARGET_SOURCE_CODEGEN_C_HOST_H_

#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "codegen_c.h"
Expand Down Expand Up @@ -63,8 +63,8 @@ class CodeGenCHost final : public CodeGenC {

private:
std::string module_name_;
/* \brief tracks declared global variables which live despite GetUniqueName */
std::set<std::string> declared_globals_;
/* \brief mapping global packed func to the unique name */
std::unordered_map<std::string, std::string> declared_globals_;
/* \brief names of the functions declared in this module */
Array<String> function_names_;
/*! \brief whether to emit asserts in the resulting C code */
Expand Down
30 changes: 26 additions & 4 deletions tests/python/contrib/test_cblas.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ def get_numpy(a, b, bb, transa, transb):
b = b.transpose()
return np.dot(a, b) + bb

def compile(f, name="test_matmul_add", ext=".so"):
path = name + ext
f.export_library(path)
mod = tvm.runtime.load_module(path)
f = mod[name]
return f

def verify(target="llvm"):
if not tvm.testing.device_enabled(target):
print("skip because %s is not enabled..." % target)
Expand All @@ -50,7 +57,10 @@ def verify(target="llvm"):
print("skip because extern function is not available")
return
dev = tvm.cpu(0)
f = tvm.build(s, [A, B, D, bias], target)
name = "test_matmul_add"
f = tvm.build(s, [A, B, D, bias], target, name=name)
if target == "c":
f = compile(f, name)
a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), dev)
b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), dev)
d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), dev)
Expand All @@ -60,7 +70,8 @@ def verify(target="llvm"):
d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), bb, transa, transb), rtol=1e-5
)

verify()
verify("llvm")
verify("c")


def test_matmul_add():
Expand Down Expand Up @@ -164,6 +175,13 @@ def get_numpy(a, b, transa, transb):
b = b.transpose(0, 2, 1)
return tvm.topi.testing.batch_matmul(a, b)

def compile(f, name="test_batch_matmul", ext=".so"):
path = name + ext
f.export_library(path)
mod = tvm.runtime.load_module(path)
f = mod[name]
return f

def verify(target="llvm"):
if not tvm.testing.device_enabled(target):
print("skip because %s is not enabled..." % target)
Expand All @@ -172,7 +190,10 @@ def verify(target="llvm"):
print("skip because extern function is not available")
return
dev = tvm.cpu(0)
f = tvm.build(s, [A, B, D], target)
name = "test_batch_matmul"
f = tvm.build(s, [A, B, D], target, name=name)
if target == "c":
f = compile(f, name)
a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), dev)
b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), dev)
d = tvm.nd.array(np.zeros((batch, n, m), dtype=D.dtype), dev)
Expand All @@ -181,7 +202,8 @@ def verify(target="llvm"):
d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), transa, transb), rtol=1e-5
)

verify()
verify("llvm")
verify("c")


def test_batch_matmul():
Expand Down
37 changes: 37 additions & 0 deletions tests/python/unittest/test_target_codegen_c_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,47 @@ def check_c():
check_c()


def test_call_packed():
def fake_func(fname="fake.func"):
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")
fake_func1 = tvm.tir.call_packed(fname, A[0])

ib.emit(fake_func1)
body = ib.get()
return A, body

def check_global_packed_func():
fname = "fake.func"
A, body = fake_func(fname)
func1 = tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "func1")
B, body = fake_func()
func2 = tvm.tir.PrimFunc([B], body).with_attr("global_symbol", "func2")
mod = tvm.IRModule({"fake_func1": func1, "fake_func2": func2})
fcode = tvm.build(mod, None, "c")
src = fcode.get_source()

# there are two locations calling the packed func
assert src.count(fname) == 2

suffix = "_packed"
packed_func_name = fname + suffix
# func name will be standardized by GetUniqueName and not exists anymore
assert src.find(packed_func_name) == -1

packed_func_real_name = "_".join(fname.split(".")) + suffix
func_declaration = "static void* %s = NULL;" % packed_func_real_name
# src only has 1 valid declaration
assert src.count(func_declaration) == 1

check_global_packed_func()


if __name__ == "__main__":
test_add()
test_add_pipeline()
test_reinterpret()
test_ceil()
test_floor()
test_round()
test_call_packed()

0 comments on commit 82fecbf

Please sign in to comment.