[Refactor] Enforce attaching storage scope to PointerType (apache#8366)
* Add storage scope to ProducerRealize, always create a buffer with scope

* update schedule_ops.cc

* update schedule_postproc_to_primfunc.cc

* restore more realize_scope

This reverts commit b66c3ba.

* make the default scope be "" instead of None in ir builder

* restore realize_scope visit in storage_flatten.cc

* update storage_access.cc

* make sure buffer var is of PointerType in ir builder

This reverts commit e650b6c.

* enforce default storage scope of global

* added remap pass but does not work yet

* fixed all reduce issue

This reverts commit 8e20003.

* simplify

* trying mitigation for aot test

* merge remaining changes from initial branch

* remove use of attr::storage_scope from codegen

* restore a visit to AttrStmt with attr::storage_scope in storage_rewrite

* disable check

* lint fix

* revert default scope to ""

* format

* fix volatile access to shared mem in lower all reduce

* fixed gpu cooperative load/store test

* pass storage scope to PointerType in tvm script parser

This reverts commit 99cfb9d18781dcfdea169d920450f9063ab18b6b.

* fixed tvmscript roundtrip test

* fixed tir flatten buffer test

* fixed test_tir_transform_hoist_if.py

* use storage scope global by default in aot_executor_codegen.cc

* add missing default storage scope in create_primfunc.cc

* restore StorageInfo struct in llvm backend

* UpdateStorageScope -> WithStorageScope

* fixed lower warp memory test

* GetStorageScope -> GetPtrStorageScope

* Enable storage scope invariant check in AttrStmt constructor

* remove GetPtrStorageScope and WithStorageScope from public header

* move RemapStorageScope to its own file

* add more methods to RemapStorageScope

* update lower_thread_allreduce to use RemapStorageScope

* RemapStorageScope -> UpdatePointerStorageScope

* remove realize_scope from hybrid script

* removed realize_scope in schedule_ops

* remove realize_scope from schedule_postproc_to_primfunc

* remove remaining realize_scope usage from schedule_ops.cc

* remove realize_scope usage from storage_flatten.cc

* fixed test_tir_transform_lower_warp_memory.py following realize_scope removal

* Address comments

* Remove blank line diff

Co-authored-by: Masahiro Masuda <masahi@[email protected]>
Co-authored-by: masa <[email protected]>
3 people authored Jul 13, 2021
1 parent f62917e commit 1a26733
Showing 53 changed files with 386 additions and 196 deletions.
2 changes: 0 additions & 2 deletions docs/dev/inferbound.rst
@@ -447,13 +447,11 @@ Here is the IR after ScheduleOps (note that loops with extent 1 have been preserved

::

// attr [compute(D, 0x2c070b0)] realize_scope = ""
realize D([0, 4], [0, 5], [0, 16]) {
produce D {
for (di, 0, 4) {
for (dj, 0, 5) {
for (dk, 0, 16) {
// attr [compute(C, 0x2c29990)] realize_scope = ""
realize C([dj, 1], [dk, 1]) {
produce C {
for (i, 0, 1) {
15 changes: 8 additions & 7 deletions include/tvm/te/operation.h
@@ -125,11 +125,12 @@ class TVM_DLL OperationNode : public Object {
* \param stage the op's stage.
* \param realize_map The realization domain map of the operators.
* \param body The body that is going to get
* \param storage_scope The storage scope associated with this realization
* \return A realization statement that wraps body.
*/
virtual Stmt BuildRealize(const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const = 0;
const std::unordered_map<IterVar, Range>& realize_map, const Stmt& body,
String storage_scope = "") const = 0;
/*!
* \brief Build the statement that provide the output tensors.
* \param stage The schedule stage of the op.
@@ -168,7 +169,7 @@ class PlaceholderOpNode : public OperationNode {
void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
const Stmt& body, String storage_scope = "") const final;
Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;

@@ -212,7 +213,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode {
void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
const Stmt& body, String storage_scope = "") const final;
virtual size_t num_schedulable_dims() const = 0;

static constexpr const char* _type_key = "BaseComputeOp";
@@ -370,7 +371,7 @@ class ScanOpNode : public OperationNode {
void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
const Stmt& body, String storage_scope = "") const final;
Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;

@@ -433,7 +434,7 @@ class ExternOpNode : public OperationNode {
void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
const Stmt& body, String storage_scope = "") const final;
Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;

@@ -498,7 +499,7 @@ class HybridOpNode : public OperationNode {
void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
const Stmt& body, String storage_scope = "") const final;
Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;

3 changes: 2 additions & 1 deletion include/tvm/tir/buffer.h
@@ -191,12 +191,13 @@ class Buffer : public ObjectRef {
* \param shape The shape of the buffer,
* \param dtype The content data type.
* \param name The name of the buffer
* \param storage_scope The storage scope associated with this buffer
* \param span The location of this object in the source code.
* \return The created buffer.
* \sa Buffer for complete constructor.
*/
TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
String name = "buffer", Span span = Span());
String name = "buffer", String storage_scope = "", Span span = Span());

/*!
* \brief Base node for data producers.
9 changes: 7 additions & 2 deletions include/tvm/tir/stmt.h
@@ -464,25 +464,30 @@ class ProducerRealizeNode : public StmtNode {
PrimExpr condition;
/*! \brief The body of realization. */
Stmt body;
/*! \brief The storage scope associated with this realization. */
String storage_scope;

void VisitAttrs(AttrVisitor* v) {
v->Visit("producer", &producer);
v->Visit("bounds", &bounds);
v->Visit("condition", &condition);
v->Visit("body", &body);
v->Visit("storage_scope", &storage_scope);
v->Visit("span", &span);
}

bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const {
return equal(producer, other->producer) && equal(bounds, other->bounds) &&
equal(condition, other->condition) && equal(body, other->body);
equal(condition, other->condition) && equal(body, other->body) &&
equal(storage_scope, other->storage_scope);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(producer);
hash_reduce(bounds);
hash_reduce(condition);
hash_reduce(body);
hash_reduce(storage_scope);
}

static constexpr const char* _type_key = "tir.ProducerRealize";
@@ -496,7 +501,7 @@ class ProducerRealizeNode : public StmtNode {
class ProducerRealize : public Stmt {
public:
TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body,
Span span = Span());
String storage_scope = "", Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode);
};
2 changes: 1 addition & 1 deletion python/tvm/script/scope_handler.py
@@ -140,7 +140,7 @@ def enter_scope(

def setup_buffer_var(extents, dtype, scope, condition=True, span: Span = None):
"""Setup buffer var for a given type."""
buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype))
buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope)
self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)

setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span))
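For orientation, a minimal Python sketch (assuming the post-merge TVM API; the variable names are hypothetical) of the core idea above: the storage scope is now attached to the PointerType of the buffer variable itself rather than carried by a separate attribute.

import tvm

# Scope lives in the pointer type of the variable.
ptr_type = tvm.ir.PointerType(tvm.ir.PrimType("float32"), "shared")
buf_var = tvm.tir.Var("A_shared", ptr_type)
assert buf_var.type_annotation.storage_scope == "shared"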
16 changes: 16 additions & 0 deletions python/tvm/script/special_stmt.py
@@ -491,6 +491,22 @@ def var(dtype, span):
super().__init__(var, def_symbol=True)


@register
class BufferVarDef(SpecialStmt):
"""Special function for defining a variable of pointer type"""

def __init__(self):
def buffer_var(dtype, storage_scope, span):
assert isinstance(
self.node, ast.Assign
), f"BufferVarDef expected ast.Assign but got {type(self.node)}"
ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope)
v = te.var(self.node.lhs.id.name, ptr_type, span=span)
self.context.update_symbol(v.name, v, self.node)

super().__init__(buffer_var, def_symbol=True)


@register
class EnvThread(SpecialStmt):
"""Bind a var to thread env"""
3 changes: 1 addition & 2 deletions python/tvm/te/hybrid/parser.py
@@ -207,8 +207,7 @@ def wrap_up_realize(self, node, body):
_domain = [Range.from_min_extent(0, i) for i in _buf.shape]
_dtype = _buf.dtype
_true = tvm.runtime.convert(True)
body = tvm.tir.ProducerRealize(_buf, _domain, _true, body)
body = tvm.tir.AttrStmt(_buf.op, "realize_scope", tvm.runtime.convert(_scope), body)
body = tvm.tir.ProducerRealize(_buf, _domain, _true, body, tvm.runtime.convert(_scope))

for elem in to_pop:
self.symbols.pop(elem)
2 changes: 1 addition & 1 deletion python/tvm/tir/buffer.py
@@ -252,7 +252,7 @@ def decl_buffer(
# Bool is represented as uint1 in the IR, but stored as int8
storage_type = PrimType(dtype)
storage_type = PrimType("int8") if storage_type.dtype == "bool" else storage_type
data = Var(name, PointerType(storage_type), span)
data = Var(name, PointerType(storage_type, scope), span)
return _ffi_api.Buffer( # type: ignore
data,
dtype,
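As a hedged usage sketch of the updated decl_buffer path (not code from this PR; names are illustrative), passing a scope now types the backing data variable directly:

import tvm

buf = tvm.tir.decl_buffer((16, 16), "float32", name="A_buf", scope="shared")
# The buffer's data var is typed PointerType(PrimType("float32"), "shared"),
# so no storage_scope AttrStmt is needed downstream.
print(buf.data.type_annotation.storage_scope)  # "shared"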
11 changes: 7 additions & 4 deletions python/tvm/tir/ir_builder.py
@@ -394,7 +394,7 @@ def let(self, var_name, value):
self.emit(lambda x: _stmt.LetStmt(var, value, x))
return var

def allocate(self, dtype, shape, name="buf", scope=None):
def allocate(self, dtype, shape, name="buf", scope=""):
"""Create a allocate statement.
Parameters
@@ -416,15 +416,15 @@ def allocate(self, dtype, shape, name="buf", scope=None):
buffer : BufferVar
The buffer var representing the buffer.
"""
buffer_var = _expr.Var(name, PointerType(PrimType(dtype)))
buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope))
if not isinstance(shape, (list, tuple, _container.Array)):
shape = [shape]
if scope:
self.scope_attr(buffer_var, "storage_scope", scope)
self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x))
return BufferVar(self, buffer_var, shape, dtype)

def pointer(self, content_type, name="ptr"):
def pointer(self, content_type, name="ptr", scope=""):
"""Create pointer variable with content type.
Parameters
@@ -435,12 +435,15 @@
name : str, optional
The name of the pointer.
scope : str, optional
The scope of the pointer.
Returns
-------
ptr : BufferVar
The buffer var representing the buffer.
"""
buffer_var = _expr.Var(name, dtype="handle")
buffer_var = _expr.Var(name, PointerType(PrimType(content_type), scope))
return BufferVar(self, buffer_var, None, content_type)

def buffer_ptr(self, buf, shape=None):
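A small ir_builder sketch under the same assumption (illustrative usage, not taken from the PR's tests): the scope argument now sets the pointer type of the allocated or declared variable instead of emitting a scope_attr.

import tvm

ib = tvm.tir.ir_builder.create()
red_buf = ib.allocate("float32", 64, name="red_buf", scope="shared")
idx_ptr = ib.pointer("int32", name="idx_ptr", scope="global")
red_buf[0] = 0.0
body = ib.get()
# body contains an Allocate whose buffer var is typed
# PointerType(PrimType("float32"), "shared"); no storage_scope AttrStmt is emitted.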
13 changes: 11 additions & 2 deletions python/tvm/tir/stmt.py
@@ -374,13 +374,22 @@ class ProducerRealize(Stmt):
body : Stmt
The realize body
storage_scope : str
The storage scope associated with this realization
span : Optional[Span]
The location of this itervar in the source code.
"""

def __init__(self, producer, bounds, condition, body, span=None):
def __init__(self, producer, bounds, condition, body, storage_scope="", span=None):
self.__init_handle_by_constructor__(
_ffi_api.ProducerRealize, producer, bounds, condition, body, span # type: ignore
_ffi_api.ProducerRealize,
producer,
bounds,
condition,
body,
storage_scope,
span, # type: ignore
)


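A minimal construction sketch of the new ProducerRealize signature (illustrative only; the placeholder and body are hypothetical):

import tvm
from tvm import te

A = te.placeholder((16,), dtype="float32", name="A")
bounds = [tvm.ir.Range(0, 16)]
inner = tvm.tir.Evaluate(tvm.tir.const(0, "int32"))
realize = tvm.tir.ProducerRealize(A, bounds, tvm.tir.const(True, "bool"), inner, "shared")
print(realize.storage_scope)  # "shared"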
9 changes: 2 additions & 7 deletions src/contrib/hybrid/codegen_hybrid.cc
@@ -315,10 +315,6 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) {
indent_ += tab_;
PrintStmt(op->body);
indent_ -= tab_;
} else if (op->attr_key == tir::attr::realize_scope) {
auto v = Downcast<Operation>(op->node);
alloc_storage_scope_[v] = op->value.as<StringImmNode>()->value;
PrintStmt(op->body);
} else {
// For now we ignore the unsupported AttrStmt
PrintStmt(op->body);
@@ -327,8 +323,7 @@

void CodeGenHybrid::VisitStmt_(const ProducerRealizeNode* op) {
auto tensor = Downcast<Tensor>(op->producer);
ICHECK(alloc_storage_scope_.count(tensor->op));
if (!alloc_storage_scope_[tensor->op].empty()) {
if (!op->storage_scope.empty()) {
PrintIndent();
stream << GetTensorID(tensor) << " = allocate((";
for (size_t i = 0; i < op->bounds.size(); ++i) {
@@ -339,7 +334,7 @@
stream << "), '";
PrintType(tensor->dtype, stream);
stream << "', '";
stream << alloc_storage_scope_[tensor->op] << "')\n";
stream << op->storage_scope << "')\n";
}
PrintStmt(op->body);
}
2 changes: 0 additions & 2 deletions src/contrib/hybrid/codegen_hybrid.h
@@ -168,8 +168,6 @@ class CodeGenHybrid : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
* \param tensor The tensor to allocate a name.
*/
std::string GetTensorID(const Tensor& tensor);
/*! \brief the storage scope of allocation */
std::map<Operation, std::string> alloc_storage_scope_;
};

} // namespace contrib
14 changes: 12 additions & 2 deletions src/printer/tvmscript_printer.cc
@@ -26,6 +26,7 @@
#include <tvm/ir/module.h>
#include <tvm/node/serialization.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/function.h>
@@ -1013,8 +1014,17 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
return memo_var_[GetRef<Var>(a)].str() < memo_var_[GetRef<Var>(b)].str();
});
for (const auto& var : vars) {
header_var << Doc::NewLine() << Print(GetRef<Var>(var)) << " = tir.var(";
header_var << PrintDType(var->dtype) << ")";
auto type = GetRef<Var>(var)->type_annotation;
if (auto* ptr_type = type.as<PointerTypeNode>()) {
auto* prim_type = ptr_type->element_type.as<PrimTypeNode>();
ICHECK(prim_type);
header_var << Doc::NewLine() << Print(GetRef<Var>(var)) << " = tir.buffer_var(";
header_var << PrintDType(prim_type->dtype) << ", "
<< Doc::StrLiteral(ptr_type->storage_scope) << ")";
} else {
header_var << Doc::NewLine() << Print(GetRef<Var>(var)) << " = tir.var(";
header_var << PrintDType(var->dtype) << ")";
}
}
}
doc << Doc::Indent(4, header_attr << header_var << header_buf << body);
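As a rough illustration of the printer change (an assumed output form derived from the code above, not an excerpt from the test suite): a free handle variable used to print as X_ptr = tir.var("handle"), whereas a pointer-typed variable with a scope now prints as X_ptr = tir.buffer_var("float32", "shared").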
3 changes: 2 additions & 1 deletion src/relay/backend/aot_executor_codegen.cc
@@ -722,7 +722,8 @@ class AOTExecutorCodegen : public ExprVisitor {
// Define the storage allocator ids
for (auto kv : storage_device_map_) {
for (auto sid : kv.second->storage_ids) {
te::Var buffer_var(MakeString("sid_", sid), PointerType(PrimType(DataType::Int(8))));
te::Var buffer_var(MakeString("sid_", sid),
PointerType(PrimType(DataType::Int(8)), "global"));
sids_table_[sid] = buffer_var;
}
}
4 changes: 3 additions & 1 deletion src/runtime/thread_storage_scope.h
@@ -118,7 +118,9 @@ struct StorageScope {
*/
static StorageScope Create(const std::string& s) {
StorageScope r;
if (s.compare(0, 6, "global") == 0) {
if (s.empty()) {
r.rank = StorageRank::kGlobal;
} else if (s.compare(0, 6, "global") == 0) {
r.rank = StorageRank::kGlobal;
r.tag = s.substr(6, std::string::npos);
} else if (s.compare(0, 6, "shared") == 0) {
5 changes: 3 additions & 2 deletions src/target/llvm/codegen_amdgpu.cc
@@ -84,7 +84,8 @@ class CodeGenAMDGPU : public CodeGenLLVM {
if (info.alignment > 16) {
info.alignment = 16;
}
if (info.scope.rank == runtime::StorageRank::kLocal) {
auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
if (storage_scope.rank == runtime::StorageRank::kLocal) {
// const int local_address_space = 5;
// TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
@@ -99,7 +100,7 @@
}
buf = alloca;
} else {
ICHECK(info.scope.rank == runtime::StorageRank::kShared)
ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
<< "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3
const unsigned shared_address_space = 3;
8 changes: 2 additions & 6 deletions src/target/llvm/codegen_llvm.cc
@@ -501,7 +501,8 @@ void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExp
auto it = alloc_storage_info_.find(buf_var);
if (it != alloc_storage_info_.end()) {
const StorageInfo& info = it->second;
*p_native_bits = NativeVectorBits(info.scope);
*p_native_bits =
NativeVectorBits(runtime::StorageScope::Create(GetPtrStorageScope(GetRef<Var>(buf_var))));
max_align_bits = info.alignment * 8;
} else {
*p_native_bits = native_vector_bits_;
@@ -1390,11 +1391,6 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value));
}
}
} else if (op->attr_key == tir::attr::storage_scope) {
const VarNode* v = op->node.as<VarNode>();
ICHECK(v);
alloc_storage_info_[v].scope =
runtime::StorageScope::Create(op->value.as<StringImmNode>()->value);
} else if (op->attr_key == tir::attr::storage_alignment) {
const VarNode* v = op->node.as<VarNode>();
ICHECK(v);
2 changes: 0 additions & 2 deletions src/target/llvm/codegen_llvm.h
@@ -163,8 +163,6 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
protected:
/*! \brief The storage information */
struct StorageInfo {
/*! \brief The storage scope */
runtime::StorageScope scope;
/*! \brief The alignment of allocation */
int alignment{0};
};