Skip to content

Commit

Permalink
[TensorIR] CreatePrimFunc from TE (apache#7987)
Browse files Browse the repository at this point in the history
Co-authored-by: Tianqi Chen <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
  • Loading branch information
4 people authored May 7, 2021
1 parent 254563a commit 4122a6a
Show file tree
Hide file tree
Showing 5 changed files with 650 additions and 0 deletions.
1 change: 1 addition & 0 deletions include/tvm/tir/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ class IterVar : public ObjectRef {
inline operator PrimExpr() const;

TVM_DEFINE_OBJECT_REF_METHODS(IterVar, ObjectRef, IterVarNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(IterVarNode);
};

// inline implementations
Expand Down
1 change: 1 addition & 0 deletions python/tvm/te/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from .tag import tag_scope
from .operation import placeholder, compute, scan, extern, var, size_var
from .operation import thread_axis, reduce_axis
from .operation import create_prim_func

from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, ExternOp, HybridOp
from .autodiff import gradient
Expand Down
50 changes: 50 additions & 0 deletions python/tvm/te/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
""" Operation class for computation declaration."""
# pylint: disable=invalid-name
from numbers import Integral as _Integral
from typing import List

import tvm._ffi
import tvm.tir
Expand Down Expand Up @@ -426,3 +427,52 @@ def reduce_axis(dom, name="rv", thread_tag="", span=None):
An iteration variable representing the value.
"""
return tvm.tir.IterVar(dom, name, 2, thread_tag, span)


def create_prim_func(ops: List[_tensor.Tensor]) -> tvm.tir.PrimFunc:
    """Create a TensorIR PrimFunc from tensor expression

    Parameters
    ----------
    ops : List[Tensor]
        The source expression. A list or tuple of tensors is forwarded as-is;
        any other value (e.g. a single Tensor) is wrapped into a one-element
        list first.

    Example
    -------
    We define a matmul kernel using following code:

    .. code-block:: python

        import tvm
        from tvm import te

        A = te.placeholder((128, 128), name="A")
        B = te.placeholder((128, 128), name="B")
        k = te.reduce_axis((0, 128), "k")
        C = te.compute((128, 128), lambda x, y: te.sum(A[x, k] * B[y, k], axis=k), name="C")
        func = te.create_prim_func([A, B, C])
        print(tvm.script.asscript(func))

    If we want to use TensorIR schedule to do transformations on such kernel,
    we need to use `create_prim_func([A, B, C])` to create a schedulable PrimFunc.
    The generated function looks like:

    .. code-block:: python

        @tvm.script.tir
        def tir_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
            A = tir.match_buffer(a, (128, 128))
            B = tir.match_buffer(b, (128, 128))
            C = tir.match_buffer(c, (128, 128))

            with tir.block([128, 128, tir.reduce_axis(0, 128)]) as [i, j, k]:
                with tir.init():
                    C[i, j] = 0.0
                C[i, j] += A[i, k] * B[j, k]

    Returns
    -------
    func : tir.PrimFunc
        The created function.
    """
    # Tuples are sequences of tensors too; wrapping one into a list would hand
    # the FFI a nested container instead of a flat list of tensors.
    if not isinstance(ops, (list, tuple)):
        ops = [ops]
    return _ffi_api.CreatePrimFunc(ops)
306 changes: 306 additions & 0 deletions src/te/operation/create_primfunc.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/runtime/registry.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>

#include <algorithm>

#include "../schedule/graph.h"

namespace tvm {
namespace tir {

/*!
 * \brief The helper mutator that transforms ProducerLoad to BufferLoad.
 *
 * Every ProducerLoad whose producer tensor is registered in the supplied map is
 * replaced by a BufferLoad on the mapped buffer. The indices of the load are
 * mutated first, so ProducerLoads nested inside an index expression
 * (e.g. `A[B[i]]`) are rewritten as well.
 */
class ProducerToBufferTransformer : public StmtExprMutator {
 public:
  /*!
   * \param tensor2buffers Map from tensors to buffers. It is held by reference,
   *        so it must outlive this transformer; entries added after construction
   *        are visible to later visits.
   */
  explicit ProducerToBufferTransformer(const std::unordered_map<te::Tensor, Buffer>& tensor2buffers)
      : tensor2buffers_(tensor2buffers) {}

  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
    // Let the base mutator rewrite the indices first; otherwise a ProducerLoad
    // used as an index would be left untouched in the result.
    auto visited_op = Downcast<ProducerLoad>(StmtExprMutator::VisitExpr_(op));
    te::Tensor tensor = Downcast<te::Tensor>(visited_op->producer);
    auto it = tensor2buffers_.find(tensor);
    ICHECK(it != tensor2buffers_.end()) << "IndexError: Cannot find the tensor " << tensor;
    const Buffer& buffer = it->second;
    return BufferLoad(buffer, visited_op->indices);
  }

 private:
  /*! \brief The map from tensors to their buffers (not owned). */
  const std::unordered_map<te::Tensor, Buffer>& tensor2buffers_;
};

/*! \brief Bookkeeping shared by the helpers while a PrimFunc is built from tensors. */
struct CreateFuncInfo {
  /*! \brief The tensors passed in as the argument list. */
  Array<te::Tensor> arg_list;
  /*! \brief Lookup table from each tensor to the buffer that stands for it. */
  std::unordered_map<te::Tensor, Buffer> tensor2buffers;
  /*! \brief Rewriter from ProducerLoad to BufferLoad; it refers to tensor2buffers. */
  ProducerToBufferTransformer transformer;
  /*! \brief Buffers that have to be allocated at the root of the function. */
  Array<Buffer> root_alloc;
  /*! \brief Per-name counters used to hand out unique block names. */
  std::unordered_map<String, int> name_count;

  // `transformer` is declared after `tensor2buffers`, so the map it refers to is
  // already constructed when the transformer is initialized.
  explicit CreateFuncInfo(Array<te::Tensor> arg_list)
      : arg_list(std::move(arg_list)), transformer(tensor2buffers) {}

  /*! \brief Whether `tensor` is one of the tensors in arg_list. */
  bool IsArg(const te::Tensor& tensor) const {
    for (const te::Tensor& arg : arg_list) {
      if (tensor == arg) {
        return true;
      }
    }
    return false;
  }

  /*!
   * \brief Produce a name that has not been handed out before.
   * \param prefix The preferred name.
   * \return `prefix` itself when it is still free; otherwise `prefix_N`, where N
   *         comes from bumping the counter stored for `prefix` until the
   *         resulting name is free. The returned name is recorded as used.
   */
  String GetUniqueName(const String& prefix) {
    auto it = name_count.find(prefix);
    if (it == name_count.end()) {
      // First request for this name: hand it out unchanged.
      name_count[prefix] = 0;
      return prefix;
    }
    String candidate = prefix;
    do {
      candidate = prefix + "_" + std::to_string(++it->second);
    } while (name_count.count(candidate));
    name_count[candidate] = 0;
    return candidate;
  }
};

/*!
 * \brief Build the BlockRealize that computes one output tensor of a ComputeOp.
 * \param compute_op The compute operation the tensor belongs to.
 * \param tensor The output tensor to generate the block for.
 * \param bindings The values bound to the block iter vars (used as iter_values).
 * \param expr_body The expression that computes the tensor; either a Reduce or a
 *        data-parallel expression.
 * \param info Shared state; the buffer declared for `tensor` is recorded in
 *        tensor2buffers (and in root_alloc when the tensor is not an argument),
 *        and a unique block name is reserved.
 * \return The BlockRealize wrapping the generated block.
 */
BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::Tensor& tensor,
Array<PrimExpr> bindings, PrimExpr expr_body,
CreateFuncInfo* info) {
// Step 1. Push_back data_par axis and reduce_axis into block_vars.
// Each axis gets a fresh Var (same name hint and dtype); var_map remembers the
// old-var -> new-var mapping so the body can be substituted later.
Array<IterVar> iter_vars;
std::unordered_map<const VarNode*, PrimExpr> var_map;
iter_vars.reserve(compute_op->axis.size() + compute_op->reduce_axis.size());
auto f_push_block_vars = [&iter_vars, &var_map](const Array<IterVar>& iters) {
for (IterVar iter_var : iters) {
// Create new var
Var new_var(iter_var->var->name_hint, iter_var->var->dtype);
var_map[iter_var->var.get()] = new_var;

// iter_var is a by-value handle here; CopyOnWrite is presumably expected to
// detach it from the node still referenced by compute_op before mutation
// (NOTE(review): confirm).
IterVarNode* iter_var_node = iter_var.CopyOnWrite();
iter_var_node->dom = Range::FromMinExtent(iter_var->dom->min, iter_var->dom->extent);
iter_var_node->var = new_var;
iter_vars.push_back(iter_var);
}
};
// Data-parallel axes first, then reduce axes; this matches the order in which
// the caller builds `bindings`.
f_push_block_vars(compute_op->axis);
f_push_block_vars(compute_op->reduce_axis);

// Step 2. Declare buffer and update tensor2buffers
Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint());
info->tensor2buffers[tensor] = buffer;

// Step 3. Add Buffer to root_alloc
// Tensors in the argument list are not added to root_alloc.
if (!info->IsArg(tensor)) {
info->root_alloc.push_back(buffer);
}

// Step 4. Calculate indices for BufferStore
// Only the data-parallel axes index the output; reduce axes do not.
Array<PrimExpr> indices;
indices.reserve(compute_op->axis.size());
for (const IterVar& iter_var : compute_op->axis) {
auto it = var_map.find(iter_var->var.get());
ICHECK(it != var_map.end());
indices.push_back(it->second);
}

// Step 5. Create block body.
Optional<Stmt> init = NullOpt;
Stmt body;
if (const auto* reduce = expr_body.as<ReduceNode>()) {
// Case 1. Reduce compute
// Only single-source reductions are handled. The update combines the current
// buffer value with the rewritten source; the init statement stores the
// combiner's identity element.
ICHECK_EQ(reduce->source.size(), 1);
const PrimExpr& lhs = BufferLoad(buffer, indices);
const PrimExpr& rhs = Substitute(info->transformer(reduce->source[0]), var_map);
ICHECK(lhs->dtype == rhs->dtype);
body = BufferStore(buffer, reduce->combiner.get()->operator()({lhs}, {rhs})[0], indices);
init = BufferStore(buffer, reduce->combiner->identity_element[0], indices);
} else {
// Case 2. Data parallel compute
body = BufferStore(buffer, Substitute(info->transformer(expr_body), var_map), indices);
}

// Step 6. Add script_parsing_detect_access attr for auto complete the whole IR.
// NOTE(review): the meaning of the value 3 is not visible in this file;
// presumably it selects which access regions get detected - confirm against
// the consumer of this attribute.
Map<String, ObjectRef> annotations = compute_op->attrs;
annotations.Set(tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3));

// Step 7. Create Block and BlockRealize.
// reads/writes are left empty here; presumably they are filled in by the
// completion step run in CreatePrimFunc (NOTE(review): confirm).
return BlockRealize(/*iter_values=*/std::move(bindings),
/*predicate=*/Bool(true),
/*block=*/
Block(/*iter_vars=*/std::move(iter_vars),
/*reads=*/{},
/*writes=*/{},
/*name_hint=*/info->GetUniqueName(tensor->GetNameHint()),
/*body=*/std::move(body),
/*init=*/std::move(init),
/*alloc_buffers=*/{},
/*match_buffers=*/{},
/*annotations=*/std::move(annotations)));
}

/*!
 * \brief Lower one ComputeOp into a loop nest whose innermost body holds one
 *        block per output tensor.
 * \param compute_op The compute operation to lower.
 * \param info Shared state, updated by GenerateBlockFromTensor for every output.
 * \return The generated statement: serial loops over the data-parallel axes
 *         followed by the reduce axes, wrapping the sequence of blocks.
 */
Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* info) {
  // Step 1. Creating loop vars for block bindings.
  Array<IterVar> axes = compute_op->axis;
  axes.insert(axes.end(), compute_op->reduce_axis.begin(), compute_op->reduce_axis.end());
  Array<PrimExpr> bindings;
  bindings.reserve(axes.size());
  for (size_t i = 0; i < axes.size(); ++i) {
    // The loop var is bound to a block iter var that GenerateBlockFromTensor
    // creates with the dtype of the axis var, so it must carry that same dtype
    // instead of the default one.
    bindings.push_back(Var("i" + std::to_string(i), axes[i]->var->dtype));
  }

  // Step 2. Generate block bodies.
  Array<Stmt> seq_stmt;
  for (int i = 0; i < compute_op->num_outputs(); ++i) {
    const te::Tensor& tensor = compute_op.output(i);
    PrimExpr expr_body = compute_op->body[i];
    seq_stmt.push_back(
        GenerateBlockFromTensor(compute_op, tensor, bindings, std::move(expr_body), info));
  }
  Stmt body = SeqStmt::Flatten(seq_stmt);

  // Step 3. Generate loop nesting.
  // Walk the axes backwards so that the first axis ends up as the outermost loop.
  for (size_t i = axes.size(); i > 0; --i) {
    const IterVar& axis = axes[i - 1];
    const Var& loop_var = Downcast<Var>(bindings[i - 1]);
    body = For(loop_var, axis->dom->min, axis->dom->extent, ForKind::kSerial, body);
  }

  return body;
}

/*!
 * \brief Wrap the body of an ExternOp into a block without iter vars.
 * \param extern_op The extern operation to convert.
 * \param info Shared state; every input tensor must already have a buffer in
 *        tensor2buffers. The output placeholders are registered as the buffers
 *        of the output tensors (and added to root_alloc when they are not
 *        arguments).
 * \return A BlockRealize whose block reads the full region of every input
 *         buffer and writes the full region of every output placeholder.
 */
Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* info) {
  // Inputs: redirect the data var of each input placeholder to the data var of
  // the buffer that was created earlier for the corresponding input tensor.
  ICHECK_EQ(extern_op->inputs.size(), extern_op->input_placeholders.size());
  std::unordered_map<const VarNode*, PrimExpr> var_map;
  const size_t num_inputs = extern_op->inputs.size();
  for (size_t idx = 0; idx < num_inputs; ++idx) {
    auto it = info->tensor2buffers.find(extern_op->inputs[idx]);
    ICHECK(it != info->tensor2buffers.end());
    const Buffer& in_placeholder = extern_op->input_placeholders[idx];
    var_map[in_placeholder->data.get()] = it->second->data;
  }

  // Outputs: the placeholder buffer itself becomes the buffer of the tensor.
  ICHECK_EQ(extern_op->num_outputs(), extern_op->output_placeholders.size());
  const int num_outputs = extern_op->num_outputs();
  for (int idx = 0; idx < num_outputs; ++idx) {
    const te::Tensor& out_tensor = extern_op.output(idx);
    const Buffer& out_placeholder = extern_op->output_placeholders[idx];
    info->tensor2buffers[out_tensor] = out_placeholder;
    if (info->IsArg(out_tensor)) {
      continue;
    }
    info->root_alloc.push_back(out_placeholder);
  }

  // Access regions: full region of every input buffer and output placeholder.
  Array<BufferRegion> reads;
  for (const te::Tensor& in_tensor : extern_op->inputs) {
    // Presence of the key was verified in the first loop.
    reads.push_back(BufferRegion::FullRegion(info->tensor2buffers.at(in_tensor)));
  }
  Array<BufferRegion> writes;
  for (const Buffer& out_placeholder : extern_op->output_placeholders) {
    writes.push_back(BufferRegion::FullRegion(out_placeholder));
  }

  Stmt block_body = Substitute(extern_op->body, var_map);

  // The block has neither iter vars nor bindings.
  Block block(/*iter_vars=*/{},
              /*reads=*/std::move(reads),
              /*writes=*/std::move(writes),
              /*name_hint=*/info->GetUniqueName(extern_op->name),
              /*body=*/std::move(block_body),
              /*init=*/NullOpt,
              /*alloc_buffers=*/{},
              /*match_buffers=*/{},
              /*annotations=*/extern_op->attrs);
  return BlockRealize(/*iter_values=*/{},
                      /*predicate=*/Bool(true),
                      /*block=*/std::move(block));
}

/*!
 * \brief Use Tensor Expression to create a schedulable TensorIR func.
 * \param arg_list The tensors that become the parameters of the function, in the
 *        given order. Every placeholder reachable from them must be listed.
 * \return The result of calling the registered global function "script.Complete"
 *         on the assembled PrimFunc and the list of root-allocated buffers.
 */
PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list) {
  // Step 1. Create tensor read graph.
  Array<te::Operation> arg_ops;
  for (const te::Tensor& arg : arg_list) {
    arg_ops.push_back(arg->op);
  }
  te::ReadGraph g = te::CreateReadGraph(arg_ops);
  Array<te::Operation> order = te::PostDFSOrder(arg_ops, g);

  // Step 2. Checking all Operations are supported.
  // Done up front so that nothing is generated for a graph that cannot be converted.
  for (const te::Operation& op : order) {
    if (!(op->IsInstance<te::PlaceholderOpNode>() || op->IsInstance<te::ComputeOpNode>() ||
          op->IsInstance<te::ExternOpNode>())) {
      LOG(FATAL) << "TypeError: Unsupported Operation: " << op->GetTypeKey() << ". "
                 << "Only te.placeholder, te.compute and te.extern are allowed for now.";
    }
  }

  // Information used in CreatePrimFunc and its sub-functions.
  CreateFuncInfo info(arg_list);
  // Root body stmts.
  Array<Stmt> root_stmts;

  // Step 3. Rewrite compute stages into blocks.
  for (const te::Operation& op : order) {
    if (const auto* placeholder = op.as<te::PlaceholderOpNode>()) {
      // Case 1. PlaceholderOp (te.placeholder)
      ICHECK_EQ(op->num_outputs(), 1);
      const te::Tensor& tensor = op.output(0);
      // A placeholder has no producer, so it must be supplied as an argument.
      ICHECK(info.IsArg(tensor)) << "ValueError: placeholder " << placeholder->name
                                 << " is not in the argument list.";
      Buffer buffer = decl_buffer(placeholder->shape, placeholder->dtype, placeholder->name);
      info.tensor2buffers[tensor] = buffer;
    } else if (const auto* compute_op = op.as<te::ComputeOpNode>()) {
      // Case 2. ComputeOp (te.compute)
      root_stmts.push_back(GenerateStmtFromCompute(GetRef<te::ComputeOp>(compute_op), &info));
    } else if (const auto* extern_op = op.as<te::ExternOpNode>()) {
      // Case 3. ExternOp (te.extern)
      root_stmts.push_back(GenerateStmtFromExternOp(GetRef<te::ExternOp>(extern_op), &info));
    } else {
      // Unreachable after the check in Step 2; kept as a safeguard.
      ICHECK(false) << "TypeError: Unsupported Operation: " << op->GetTypeKey() << ". "
                    << "Only te.placeholder, te.compute and te.extern are allowed for now.";
    }
  }

  // Step 4. Create func and complete it.
  // Each argument tensor gets a handle parameter mapped to the tensor's buffer.
  Array<Var> parameters;
  Map<Var, Buffer> buffer_map;
  for (const te::Tensor& tensor : arg_list) {
    Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle()));
    parameters.push_back(arg);
    auto it = info.tensor2buffers.find(tensor);
    ICHECK(it != info.tensor2buffers.end());
    buffer_map.Set(arg, it->second);
  }
  PrimFunc func = PrimFunc(/*params=*/std::move(parameters),
                           /*body=*/SeqStmt::Flatten(root_stmts),
                           /*ret_type=*/VoidType(),
                           /*buffer_map=*/std::move(buffer_map));

  // The completion function is looked up by name at runtime.
  const auto* complete = runtime::Registry::Get("script.Complete");
  ICHECK(complete) << "Cannot find the global function script.Complete";

  return (*complete)(func, info.root_alloc);
}

// Register CreatePrimFunc as the global function "te.CreatePrimFunc"; the Python
// wrapper te.create_prim_func calls it as _ffi_api.CreatePrimFunc.
TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed([](const Array<te::Tensor>& tensors) {
return CreatePrimFunc(tensors);
});

} // namespace tir
} // namespace tvm
Loading

0 comments on commit 4122a6a

Please sign in to comment.