Use kernel_func_name from aotCompiler (#66337)
Summary:
Pull Request resolved: pytorch/pytorch#66337

Right now, the assembly code generated for a given method of the model is named "wrapper" or "func" by default; the function name is then replaced with the proper kernel_func_name after the target-specific assembly has been generated.
This PR propagates the desired kernel_func_name through the aotCompile API, so the generated function carries the required name from the start and no longer needs to be renamed afterwards.
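
For illustration, a minimal sketch of the new call shape. The wrapper function, method name, and chosen kernel name are hypothetical; only aotCompile's signature comes from this PR:

    #include <torch/csrc/jit/api/module.h>
    #include <torch/csrc/jit/mobile/nnc/aot_compiler.h>

    // Illustrative caller: compile one method and pin the emitted kernel's
    // symbol name up front, instead of renaming it in the assembly later.
    std::pair<std::unique_ptr<torch::jit::mobile::nnc::Function>, const std::string>
    compileForward(
        torch::jit::Module& module,
        const std::vector<std::vector<int64_t>>& sizes) {
      auto method = module.get_method("forward");
      auto graph = method.function().graph()->copy();
      // The fourth argument is new in this PR; it defaults to "func".
      return torch::jit::mobile::nnc::aotCompile(
          "forward", graph, sizes, "nnc_mymodel_v1_forward");
    }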

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31514095

Pulled By: priyaramani

fbshipit-source-id: b70c8e2c733600a435cd4e8b32092d37b7bf7de5
priyaramani authored and facebook-github-bot committed Oct 23, 2021
1 parent 64c68ed commit 7b55dc8
Showing 5 changed files with 27 additions and 17 deletions.
22 changes: 12 additions & 10 deletions binaries/aot_model_compiler.cc
@@ -90,6 +90,10 @@ std::string getNncKernelId() {
       ":" + version_token;
 }
 
+std::string getNncKernelFuncName(const std::string& method_name) {
+  return "nnc_" + FLAGS_model_name + "_" + FLAGS_model_version + "_" + method_name;
+}
+
 void writeOutputLlvmAssembly(const std::string& asm_code) {
   std::string output_llvm_file_name = FLAGS_output_llvm;
   if (output_llvm_file_name.empty()) {
@@ -108,18 +112,13 @@ c10::IValue preprocess(
     const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
     const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
 
-  std::string output_llvm_file_name = FLAGS_output_llvm;
-  if (output_llvm_file_name.empty()) {
-    output_llvm_file_name =
-        FLAGS_model.substr(0, FLAGS_model.find('.')) + ".compiled.ll";
-  }
-
   auto method = mod.get_method(FLAGS_method_name);
   auto graph = method.function().graph()->copy();
   auto sizes = getInputSizes(method_compile_spec);
+  auto kernel_func_name = getNncKernelFuncName(FLAGS_method_name);
 
-  std::string llvm_asm_code;
-  auto compiled = torch::jit::mobile::nnc::aotCompile(FLAGS_method_name, graph, sizes);
+  auto compiled = torch::jit::mobile::nnc::aotCompile(
+      FLAGS_method_name, graph, sizes, kernel_func_name);
   writeOutputLlvmAssembly(compiled.second);
 
   auto func = std::move(compiled.first);
@@ -141,8 +140,8 @@ int main(int argc, char** argv) {
       " --model=<model file>"
       " --model_name=<model name>"
       " --model_version=<model version>"
-      " --input_dims='1,3,224,224'"
-      " [--method_name=<mehhod name>]"
+      " --input_dims=<input dimensions like '1,3,224,224;2,2'>"
+      " [--method_name=<method name>]"
       " [--output_llvm=<llvm assembly output file path>]"
       " [--output_model=<output model file path>]");
@@ -153,6 +152,9 @@ int main(int argc, char** argv) {
   }
 
   CAFFE_ENFORCE(!FLAGS_model.empty(), c10::UsageMessage());
+  CAFFE_ENFORCE(!FLAGS_model_name.empty(), c10::UsageMessage());
+  CAFFE_ENFORCE(!FLAGS_model_version.empty(), c10::UsageMessage());
+  CAFFE_ENFORCE(!FLAGS_input_dims.empty(), c10::UsageMessage());
 
   std::string output_model_name = FLAGS_output_model;
   if (output_model_name.empty()) {
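
As a worked example of the naming scheme added above (flag values hypothetical): with --model_name=mobilenet, --model_version=v1, and --method_name=forward, getNncKernelFuncName("forward") produces the kernel symbol

    nnc_mobilenet_v1_forward

which the generated assembly now exports directly.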
7 changes: 5 additions & 2 deletions torch/csrc/jit/mobile/nnc/aot_compiler.cpp
@@ -87,7 +87,8 @@ std::unique_ptr<Function> compileMethod(
 std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
     const std::string& method_name,
     std::shared_ptr<Graph>& g,
-    const std::vector<std::vector<int64_t>>& sizes) {
+    const std::vector<std::vector<int64_t>>& sizes,
+    const std::string& kernel_func_name) {
   GRAPH_DEBUG("Input sizes ", sizes);
   GRAPH_DEBUG("Method name ", method_name);
@@ -111,7 +112,9 @@ std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
   GRAPH_DUMP("graph after shape propagation ", g);
 
   std::shared_ptr<tensorexpr::TensorExprKernel> kernel =
-      std::make_shared<tensorexpr::TensorExprKernel>(g);
+      std::make_shared<tensorexpr::TensorExprKernel>(
+          TensorExprKernel(g, {}, false, kernel_func_name));
+
   const std::string compiled_assembly = kernel->getCodeText();
 
   auto func = compileMethod(kernel, method_name, sizes);
3 changes: 2 additions & 1 deletion torch/csrc/jit/mobile/nnc/aot_compiler.h
@@ -14,7 +14,8 @@ namespace nnc {
 TORCH_API std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
     const std::string& method_name,
     std::shared_ptr<Graph>& subgraph,
-    const std::vector<std::vector<int64_t>>& sizes);
+    const std::vector<std::vector<int64_t>>& sizes,
+    const std::string& kernel_func_name = "func");
 
 } // namespace nnc
 } // namespace mobile
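
Since kernel_func_name defaults to "func" (the previously hard-coded name), existing call sites compile and behave exactly as before; only callers that care about the symbol, such as the AOT model compiler above, pass it explicitly. A sketch of both call shapes, assuming graph and sizes are already in scope:

    // Legacy call, unchanged: the kernel is still emitted as "func".
    auto unnamed = torch::jit::mobile::nnc::aotCompile("forward", graph, sizes);

    // New call: the kernel symbol is fixed at compile time.
    auto named = torch::jit::mobile::nnc::aotCompile(
        "forward", graph, sizes, "nnc_mymodel_v1_forward");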
8 changes: 5 additions & 3 deletions torch/csrc/jit/tensorexpr/kernel.cpp
@@ -1172,17 +1172,19 @@ void TensorExprKernel::compile() {
       stmt,
       bufferArgs_,
       device_,
-      SubgraphUtils::generateNameForGraph(graph_));
+      kernel_func_name_);
 }
 
 TensorExprKernel::TensorExprKernel(
     const std::shared_ptr<Graph>& subgraph,
     std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings,
-    bool pre_alloc /*= false*/)
+    bool pre_alloc /*= false*/,
+    const std::string& kernel_func_name)
     : graph_(subgraph),
       code_(subgraph, ""),
       custom_lowerings_(std::move(custom_lowerings)),
-      pre_alloc_(pre_alloc) {
+      pre_alloc_(pre_alloc),
+      kernel_func_name_(kernel_func_name) {
   allow_fallback_ = fallbackAllowed();
   if (!allow_fallback_) {
     compile();
4 changes: 3 additions & 1 deletion torch/csrc/jit/tensorexpr/kernel.h
@@ -93,7 +93,8 @@ class TORCH_API TensorExprKernel {
       const std::shared_ptr<Graph>& subgraph,
       std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings =
           {},
-      bool pre_alloc = false);
+      bool pre_alloc = false,
+      const std::string& kernel_func_name = "func");
 
   void run(Stack& stack);
   void runFast(
@@ -235,6 +236,7 @@
 
   std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings_;
   bool pre_alloc_{false};
+  const std::string& kernel_func_name_;
 };
 
 TORCH_API int& getTECudaPointwiseLoopLevels();
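
The same default applies one level down: constructing a TensorExprKernel with an explicit name routes it through compile() to the code generator, replacing the old SubgraphUtils::generateNameForGraph(graph_) naming. A sketch, assuming graph is a lowerable subgraph and the kernel name is hypothetical:

    // Positional arguments mirror the declaration above: custom lowerings,
    // pre-allocation flag, then the kernel symbol name.
    torch::jit::tensorexpr::TensorExprKernel kernel(
        graph, /*custom_lowerings=*/{}, /*pre_alloc=*/false, "my_kernel");
    // The constructor invokes compile() (unless fallback is allowed), which
    // now hands "my_kernel" to CodeGen instead of a graph-derived name.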
