Skip to content

Commit

Permalink
[TensorIR][M1c] LCA detector (apache#7848)
Browse files Browse the repository at this point in the history
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Junru Shao <[email protected]>
  • Loading branch information
3 people authored Apr 16, 2021
1 parent cc79e8f commit 6aefc26
Show file tree
Hide file tree
Showing 4 changed files with 310 additions and 1 deletion.
9 changes: 9 additions & 0 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,15 @@ Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
*/
TVM_DLL size_t CalculateExprComplexity(const PrimExpr& expr);

/*!
 * \brief Detect the lowest common ancestor (LCA) of buffer access, including both high-level
 *        access (BufferLoad, BufferStore) and low-level access (Load, Store and opaque access).
 *        The LCA may be a For loop or a Block.
 * \param func The PrimFunc to be analyzed.
 * \return The Map from each buffer to the LCA of all accesses to it.
 */
TVM_DLL Map<Buffer, Stmt> DetectBufferAccessLCA(const PrimFunc& func);

// Pass variants of verification analysis
// directly throws RuntimeError when verification fails.
namespace transform {
Expand Down
22 changes: 21 additions & 1 deletion python/tvm/tir/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
# under the License.
"""Wrapping existing analysis utils."""
# pylint: disable=invalid-name

from typing import Dict
from . import _ffi_api
from ..function import PrimFunc
from .. import Buffer, Stmt


def expr_deep_equal(lhs, rhs):
Expand Down Expand Up @@ -129,3 +131,21 @@ def get_block_access_region(block, buffer_var_map):
- third: opaque regions
"""
return _ffi_api.get_block_access_region(block, buffer_var_map)


def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]:
    """Find the lowest common ancestor (LCA) of the accesses to each buffer.

    Both high-level access (BufferLoad, BufferStore) and low-level access
    (Load, Store and opaque access) are taken into account.
    The LCA may be a For loop or a Block.

    Parameters
    ----------
    func : tvm.tir.PrimFunc
        The function to be analyzed.

    Returns
    -------
    result : Dict[Buffer, Stmt]
        Map from buffer to the LCA of all accesses to it.
    """
    buffer_to_lca = _ffi_api.detect_buffer_access_lca(func)  # pylint: disable=no-member
    return buffer_to_lca
173 changes: 173 additions & 0 deletions src/tir/analysis/buffer_access_lca_detector.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tir/analysis/buffer_access_lca_detector.cc
* \brief Detect the lowest common ancestor(LCA) of buffer access
*/

#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>

#include "../../support/arena.h"

namespace tvm {
namespace tir {

/*!
 * \brief Detect the lowest common ancestor(LCA) position of Buffer access.
 * \note Only BlockNode and ForNode are considered to be LCA nodes. A buffer whose accesses are
 *       not all enclosed by one common BlockNode/ForNode (for example, it is accessed outside of
 *       every For/Block, or from two unrelated top-level scopes) has no such node and is
 *       therefore left out of the result.
 */
class LCADetector : public StmtExprVisitor {
 public:
  /*!
   * \brief Run the detection over the body of a function.
   * \param func The PrimFunc to be analyzed.
   * \return The map from each accessed buffer to the ForNode/BlockNode that is the LCA of all
   *         detected accesses to it. Buffers without an enclosing LCA node are not present.
   */
  static Map<Buffer, Stmt> Detect(const PrimFunc& func) {
    LCADetector detector;
    // Buffers bound to the function parameters can be accessed through their data var.
    for (const auto& kv : func->buffer_map) {
      const Buffer& buffer = kv.second;
      detector.buffer_var_map_.emplace(buffer->data.get(), buffer.get());
    }
    detector(func->body);
    // Prepare the return
    Map<Buffer, Stmt> buffer_lca;
    for (const auto& kv : detector.buffer_lca_) {
      // A null scope means that no single For/Block encloses every access to the buffer.
      if (kv.second == nullptr) continue;
      buffer_lca.Set(GetRef<Buffer>(kv.first), GetRef<Stmt>(kv.second->stmt));
    }
    return buffer_lca;
  }

 private:
  /*!
   * \brief The AST node information for querying LCA.
   * \note Only BlockNode and ForNode are considered, since they are the only statements whose
   *       body can be a SeqStmt (the LCA of buffer access) in TensorIR.
   */
  struct ScopeInfo {
    // The parent scope info, nullptr for an outermost For/Block
    const ScopeInfo* parent_scope_info;
    // The For/Block stmt node of this scope
    const StmtNode* stmt;
    // The scope depth in the AST, an outermost For/Block has depth 1
    int depth;
    ScopeInfo(const ScopeInfo* parent_info, const StmtNode* stmt, int depth)
        : parent_scope_info(parent_info), stmt(stmt), depth(depth) {}
  };

  /*!
   * \brief Create the scope of a For/Block below the current innermost scope and push it.
   * \param stmt The ForNode/BlockNode that opens the scope.
   * \note The caller pops the scope after visiting the statement.
   */
  void PushScope(const StmtNode* stmt) {
    // The stack always holds the leading nullptr entry, so its size is the depth of the new scope.
    const int depth = static_cast<int>(ancestor_scopes_.size());
    const ScopeInfo* parent_scope = ancestor_scopes_.back();
    const ScopeInfo* current_scope = arena_.make<ScopeInfo>(parent_scope, stmt, depth);
    ancestor_scopes_.push_back(current_scope);
  }

  void VisitStmt_(const ForNode* op) final {
    PushScope(op);
    StmtExprVisitor::VisitStmt_(op);
    ancestor_scopes_.pop_back();
  }

  void VisitStmt_(const BlockNode* op) final {
    // Buffers allocated by the block can be accessed through their data var inside it.
    for (const Buffer& buf : op->alloc_buffers) {
      buffer_var_map_.emplace(buf->data.get(), buf.get());
    }
    PushScope(op);
    StmtExprVisitor::VisitStmt_(op);
    ancestor_scopes_.pop_back();
  }

  void VisitExpr_(const BufferLoadNode* op) final {
    UpdateBufferLCA(op->buffer.get());
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const BufferStoreNode* op) final {
    UpdateBufferLCA(op->buffer.get());
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const BufferRealizeNode* op) final {
    buffer_var_map_.emplace(op->buffer->data.get(), op->buffer.get());
    StmtExprVisitor::VisitStmt_(op);
  }

  // Works for Load/Store and opaque access.
  void VisitExpr_(const VarNode* op) final { VisitBufferVar(op); }

  // Explicitly visit the buffer data var in Load and Store nodes.
  void VisitExpr_(const LoadNode* op) final {
    ExprVisitor::VisitExpr_(op);
    VisitBufferVar(op->buffer_var.get());
  }

  void VisitStmt_(const StoreNode* op) final {
    StmtVisitor::VisitStmt_(op);
    VisitBufferVar(op->buffer_var.get());
  }

  /*! \brief Record an access if the var is the data var of a known buffer. */
  void VisitBufferVar(const VarNode* op) {
    auto it = buffer_var_map_.find(op);
    if (it != buffer_var_map_.end()) {
      UpdateBufferLCA(it->second);
    }
  }

  /*!
   * \brief Merge the current innermost scope into the LCA recorded for a buffer.
   * \param buffer The buffer being accessed.
   */
  void UpdateBufferLCA(const BufferNode* buffer) {
    const ScopeInfo* current_scope = ancestor_scopes_.back();
    auto it = buffer_lca_.find(buffer);
    if (it == buffer_lca_.end()) {
      // First access: the LCA is the current scope itself. Presence in the map (rather than a
      // non-null value) marks "already seen", because nullptr is a valid LCA meaning
      // "not enclosed by any For/Block".
      buffer_lca_.emplace(buffer, current_scope);
    } else {
      it->second = LowestCommonAncestor(it->second, current_scope);
    }
  }

  /*!
   * \brief Find the lowest common ancestor of two scopes.
   * \param lhs The first scope, nullptr stands for "outside of every For/Block".
   * \param rhs The second scope, nullptr stands for "outside of every For/Block".
   * \return The deepest scope enclosing both, or nullptr if no For/Block encloses both.
   */
  static const ScopeInfo* LowestCommonAncestor(const ScopeInfo* lhs, const ScopeInfo* rhs) {
    while (lhs != rhs) {
      // One side is already outside of every scope, so nothing can enclose both.
      if (lhs == nullptr || rhs == nullptr) return nullptr;
      if (lhs->depth == rhs->depth) {
        // Different scopes at the same depth: the LCA is strictly above both.
        lhs = lhs->parent_scope_info;
        rhs = rhs->parent_scope_info;
      } else if (lhs->depth < rhs->depth) {
        rhs = rhs->parent_scope_info;
      } else {
        lhs = lhs->parent_scope_info;
      }
    }
    return lhs;
  }

  /*!
   * \brief The ancestor scope stacks info (Block and For), initialized with Null.
   *        The leading Null entry stands for the position outside of every For/Block.
   */
  std::vector<const ScopeInfo*> ancestor_scopes_ = {nullptr};
  /*! \brief The map from Buffer to its LCA ForNode/BlockNode (nullptr if there is none). */
  std::unordered_map<const BufferNode*, const ScopeInfo*> buffer_lca_ = {};
  /*! \brief The map from Buffer data to the Buffer. */
  std::unordered_map<const VarNode*, const BufferNode*> buffer_var_map_ = {};
  /*! \brief Internal arena. */
  support::Arena arena_;
};

/*!
 * \brief Detect the LCA of the accesses to each buffer of a function.
 * \param func The PrimFunc to be analyzed.
 * \return The result of LCADetector::Detect on the function.
 */
Map<Buffer, Stmt> DetectBufferAccessLCA(const PrimFunc& func) {
  Map<Buffer, Stmt> buffer_lca = LCADetector::Detect(func);
  return buffer_lca;
}

// Register DetectBufferAccessLCA as a typed global function under the name
// "tir.analysis.detect_buffer_access_lca".
// NOTE(review): presumably this is the entry the Python wrapper reaches through
// `_ffi_api.detect_buffer_access_lca` -- confirm against the FFI module setup.
TVM_REGISTER_GLOBAL("tir.analysis.detect_buffer_access_lca").set_body_typed(DetectBufferAccessLCA);
} // namespace tir
} // namespace tvm
107 changes: 107 additions & 0 deletions tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import tir
from tvm.script import ty


# TVMScript fixture used by test_buffer_load_store.
# A and B are bound to the function parameters with match_buffer; C and D are
# created inside the function with alloc_buffer.
# NOTE(review): the test walks the resulting IR tree by position, so the nesting
# of the statements below is significant -- do not restructure them.
@tvm.script.tir
def buffer_load_store_func(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128), "float32")
B = tir.match_buffer(b, (128, 128), "float32")
C = tir.alloc_buffer((128, 128), "float32")
D = tir.alloc_buffer((128, 128), "float32")
# First block: stores to A.
with tir.block([128, 128]) as [i, j]:
A[i, j] = tir.float32(0)
# Second block: has a reduce axis k; the statement under tir.init() loads A and stores B.
with tir.block([32, 32, tir.reduce_axis(0, 32)]) as [i, j, k]:
with tir.init():
for ii, jj in tir.grid(4, 4):
B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
for ii, jj in tir.grid(4, 4):
# First kk loop: accumulates C into B.
for kk in range(0, 4):
B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk]
# Second kk loop: the only statement that loads D; it loads C as well.
for kk in range(0, 4):
B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk]


# TVMScript fixture used by test_opaque_access.
# B and C are bound to the function parameters with match_buffer.
# NOTE(review): the test walks the resulting IR tree by position, so the nesting
# of the statements below is significant -- do not restructure them.
@tvm.script.tir
def buffer_opaque_access(b: ty.handle, c: ty.handle) -> None:
B = tir.match_buffer(b, [16, 16], "float32")
C = tir.match_buffer(c, [16, 16], "float32")

with tir.block([]):
tir.reads([])
tir.writes(B[0:16, 0:16])
# A comes from the low-level tir.allocate and is only touched through tir.store / tir.load.
A = tir.allocate([256], "float32", "global")
for i, j in tir.grid(16, 16):
tir.store(A, i * 16 + j, 1)
for i in range(0, 16):
for j in range(0, 16):
tir.evaluate(tir.load("float32", A, i * 16 + j))
for j in range(0, 16):
# Opaque access to B: only its data var is passed to the intrinsic call.
tir.evaluate(
tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, tir.float32(0), dtype="handle")
)

# High-level access: store to C and load from B inside a block bound to the loop vars.
for i, j in tir.grid(16, 16):
with tir.block([16, 16]) as [vi, vj]:
tir.bind(vi, i)
tir.bind(vj, j)
C[vi, vj] = B[vi, vj]


def test_buffer_load_store():
    """Check the detected LCA of every buffer in buffer_load_store_func."""
    func = buffer_load_store_func
    A, B = [func.buffer_map[x] for x in func.params]
    C, D = func.body.block.alloc_buffers
    lca = tir.analysis.detect_buffer_access_lca(func)

    # LCA of Buffer A is the root block
    root_block = func.body.block
    assert lca[A] == root_block

    # LCA of Buffer B is the reduction block
    reduce_block = root_block.body[1].body.body.body.block
    assert lca[B] == reduce_block

    # LCA of Buffer C is loop jj (C is read in both kk loops below it)
    loop_jj = reduce_block.body.body
    assert lca[C] == loop_jj

    # LCA of Buffer D is the second loop kk (the only place D is read)
    loop_kk = loop_jj.body[1]
    assert lca[D] == loop_kk


def test_opaque_access():
    """Check the detected LCA of the parameter buffers in buffer_opaque_access."""
    func = buffer_opaque_access
    buf_b, buf_c = (func.buffer_map[param] for param in func.params)
    detected = tir.analysis.detect_buffer_access_lca(func)

    # Buffer A cannot be detected since it is defined by the low-level Allocate

    root_block = func.body.block

    # LCA of Buffer B is the root block
    assert detected[buf_b] == func.body.block

    # LCA of Buffer C is the corresponding block
    copy_block = root_block.body[1].body.body.block
    assert detected[buf_c] == copy_block


if __name__ == "__main__":
    # Run every test of this file in order when it is executed as a script.
    for _test in (test_buffer_load_store, test_opaque_access):
        _test()

0 comments on commit 6aefc26

Please sign in to comment.