Skip to content

Commit

Permalink
[PASS] InjectDoubleBuffer (apache#405)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Sep 1, 2017
1 parent b8c8aad commit a45d3b0
Show file tree
Hide file tree
Showing 18 changed files with 421 additions and 14 deletions.
8 changes: 8 additions & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,14 @@ constexpr const char* pragma_scope = "pragma_scope";
* run prefetch of Tensor on the current loop scope
*/
constexpr const char* prefetch_scope = "prefetch_scope";
/*!
* \brief Marks production of double buffer data
*/
constexpr const char* double_buffer_scope = "double_buffer_scope";
/*!
* \brief Marks region used by double buffer write
*/
constexpr const char* double_buffer_write = "double_buffer_write";
/*! \brief Mark of scan update scope */
constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope */
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,14 @@ Stmt InjectVirtualThread(Stmt stmt);
*/
Stmt InjectPrefetch(Stmt stmt);

/*!
* \brief Inject double buffer into stmt.
* \param stmt The statment to be transformed.
* \param split_loop Whether split the loop containing double buffering.
* \return Transformed stmt.
*/
Stmt InjectDoubleBuffer(Stmt stmt, bool split_loop);

/*!
* \brief Rewrite storage allocation pattern.
* Moves the allocation to outer most possible scope.
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,11 @@ class Stage : public NodeRef {
* \return reference to self
*/
Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
/*!
* \brief Compute current stage with double buffering.
* \return reference to self.
*/
Stage& double_buffer(); // NOLINT(*)
/*!
* \brief whether the stage has been scheduled.
* \return whether the stage has been scheduled.
Expand Down Expand Up @@ -408,6 +413,8 @@ class StageNode : public Node {
std::string scope;
/*! \brief Whether this is an output stage */
bool is_output{false};
/*! \brief Whether apply double buffer optimization to this stage */
bool double_buffer{false};
/*!
* \brief The parent group of the current stage.
* The stage cannot be assigned to stages outside the group.
Expand All @@ -429,6 +436,7 @@ class StageNode : public Node {
v->Visit("attach_stage", &attach_stage);
v->Visit("scope", &scope);
v->Visit("is_output", &is_output);
v->Visit("double_buffer", &double_buffer);
v->Visit("group", &group);
v->Visit("num_child_stages", &num_child_stages);
}
Expand Down
8 changes: 7 additions & 1 deletion python/tvm/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class BuildConfig(object):
"offset_factor": 0,
"data_alignment": -1,
"restricted_func": True,
"double_buffer_split_loop": True,
"add_lower_pass": None
}
def __init__(self, **kwargs):
Expand Down Expand Up @@ -97,6 +98,10 @@ def build_config(**kwargs):
not to overlap. This enables more optimization.
Corresponds to restricted keyword in C99
double_buffer_split_loop: bool, default=True
Whether split the loop containing double buffer so
that the buffer fetching won't contain condition.
add_lower_pass: list of function(Stmt->Stmt), default=None
Additional lowering passes to be applied before make_api.
Expand Down Expand Up @@ -187,6 +192,7 @@ def lower(sch,
Then the Stmt before make api is returned.
"""
binds, arg_list = get_binds(args, binds)
cfg = BuildConfig.current
# normalize schedule first
sch = sch.normalize()
bounds = schedule.InferBound(sch)
Expand All @@ -198,8 +204,8 @@ def lower(sch,
stmt = ir_pass.LoopPartition(stmt)
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
stmt = ir_pass.StorageRewrite(stmt)
cfg = BuildConfig.current
stmt = ir_pass.UnrollLoop(
stmt,
cfg.auto_unroll_max_step,
Expand Down
15 changes: 15 additions & 0 deletions python/tvm/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,21 @@ def _exit_cb():
self.emit(_make.IfThenElse(prev.condition, prev.then_case, self._pop_seq()))
return WithScope(None, _exit_cb)

def new_scope(self):
"""Create new scope,
this is useful to set boundary of attr and allocate.
Returns
-------
new_scope : WithScope
The result new scope.
"""
self._seq_stack.append([])
def _exit_cb():
self.emit(self._pop_seq())
return WithScope(None, _exit_cb)

def allocate(self, dtype, shape, name="buf", scope=None):
"""Create a allocate statement.
Expand Down
9 changes: 9 additions & 0 deletions python/tvm/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,4 +589,13 @@ def storage_align(self, axis, factor, offset):
"""
_api_internal._StageStorageAlign(self, axis, factor, offset)

def double_buffer(self):
"""Compute the current stage via double buffering.
This can only be applied to intermediate stage.
This will double the storage cost of the current stage.
Can be useful to hide load latency.
"""
_api_internal._StageDoubleBuffer(self)

_init_api("tvm.schedule")
9 changes: 7 additions & 2 deletions src/api/api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -385,13 +385,18 @@ TVM_REGISTER_API("_StagePragma")
TVM_REGISTER_API("_StagePrefetch")
.set_body([](TVMArgs args, TVMRetValue *ret) {
args[0].operator Stage()
.prefetch(args[1], args[2], args[3]);
.prefetch(args[1], args[2], args[3]);
});

TVM_REGISTER_API("_StageStorageAlign")
.set_body([](TVMArgs args, TVMRetValue *ret) {
args[0].operator Stage()
.storage_align(args[1], args[2], args[3]);
.storage_align(args[1], args[2], args[3]);
});

TVM_REGISTER_API("_StageDoubleBuffer")
.set_body([](TVMArgs args, TVMRetValue *ret) {
args[0].operator Stage().double_buffer();
});

TVM_REGISTER_API("_ScheduleNormalize")
Expand Down
1 change: 1 addition & 0 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ REGISTER_PASS1(CoProcSync);
REGISTER_PASS1(LowerStorageAccessInfo);
REGISTER_PASS1(InjectVirtualThread);
REGISTER_PASS1(InjectPrefetch);
REGISTER_PASS2(InjectDoubleBuffer);
REGISTER_PASS1(LoopPartition);
REGISTER_PASS1(RemoveNoOp);
REGISTER_PASS2(SplitPipeline);
Expand Down
226 changes: 226 additions & 0 deletions src/pass/inject_double_buffer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
/*!
* Copyright (c) 2017 by Contributors
*
* \brief Inject double buffering optimization for data fetch.
* \file inject_double_buffer.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include "./ir_util.h"
#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace ir {

// Detect double buffer variables.
class DoubleBufferDetector : public IRVisitor {
public:
void Visit_(const AttrStmt* op) final {
if (op->attr_key == attr::double_buffer_scope) {
touched_.insert(op->node.as<Variable>());
IRVisitor::Visit_(op);
} else {
IRVisitor::Visit_(op);
}
}

void Visit_(const Variable* op) final {
if (touched_.count(op)) {
touched_.erase(op);
}
}
// The set of touched variable.
std::unordered_set<const Variable*> touched_;
};

class DoubleBufferInjector : public IRMutator {
public:
explicit DoubleBufferInjector(bool split_loop)
: split_loop_(split_loop) {}

Stmt Inject(const Stmt& stmt) {
DoubleBufferDetector detector;
detector.Visit(stmt);
if (detector.touched_.empty()) return stmt;
for (const Variable* v : detector.touched_) {
dbuffer_info_[v] = StorageEntry();
}
return ConvertSSA(this->Mutate(stmt));
}

Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
auto it = dbuffer_info_.find(buf);
if (it != dbuffer_info_.end()) {
it->second.scope = op->value.as<StringImm>()->value;
return Mutate(op->body);
} else {
return IRMutator::Mutate_(op, s);
}
} else if (op->attr_key == attr::double_buffer_scope) {
return MakeProducer(op, s);
} else {
return IRMutator::Mutate_(op, s);
}
}

Stmt Mutate_(const Allocate* op, const Stmt& s) final {
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
it->second.size = arith::ComputeReduce<Mul>(op->extents);
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>();
Array<Expr> new_extents{make_const(op->extents[0].type(), 2)};
for (Expr e : op->extents) {
new_extents.push_back(e);
}
CHECK(it->second.loop != nullptr);
auto& alloc_nest = loop_allocs_[it->second.loop];
alloc_nest.emplace_back(AttrStmt::make(
op->buffer_var, attr::storage_scope,
StringImm::make(it->second.scope),
Evaluate::make(0)));
alloc_nest.emplace_back(Allocate::make(
op->buffer_var, op->type, new_extents, op->condition,
Evaluate::make(0)));
return op->body;
} else {
return IRMutator::Mutate_(op, s);
}
}

Stmt Mutate_(const For* op, const Stmt& s) final {
loop_nest_.push_back(op);
Stmt stmt = IRMutator::Mutate_(op, s);
auto it = loop_pre_.find(op);
if (it != loop_pre_.end()) {
const For* old_loop = stmt.as<For>();
if (split_loop_) {
Expr new_ext = arith::ComputeExpr<Sub>(
old_loop->extent, make_const(old_loop->loop_var.type(), 1));
Stmt loop = For::make(
old_loop->loop_var, old_loop->min, new_ext,
old_loop->for_type, old_loop->device_api,
old_loop->body);
std::unordered_map<const Variable*, Expr> vmap;
vmap[old_loop->loop_var.get()] = new_ext;
Stmt end = Substitute(old_loop->body, vmap);
stmt = Block::make(loop, end);
}
stmt = Block::make(MergeSeq(it->second), stmt);
}
it = loop_allocs_.find(op);
if (it != loop_allocs_.end()) {
stmt = MergeNest(it->second, stmt);
}
loop_nest_.pop_back();
return stmt;
}

Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Store>();
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
const StorageEntry& e = it->second;
CHECK(in_double_buffer_scope_);
CHECK(e.size.defined());
return Store::make(op->buffer_var,
op->value,
e.switch_write_var * e.size + op->index,
op->predicate);
} else {
return stmt;
}
}

Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Load>();
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
const StorageEntry& e = it->second;
CHECK(e.size.defined());
CHECK(e.switch_read_var.defined());
return Load::make(op->type,
op->buffer_var,
e.switch_read_var * e.size + op->index,
op->predicate);
} else {
return expr;
}
}

Expr Mutate_(const Variable* op, const Expr& e) final {
CHECK(!dbuffer_info_.count(op));
return e;
}

private:
Stmt MakeProducer(const AttrStmt* op, const Stmt& s) {
const VarExpr buffer(op->node.node_);
CHECK_NE(loop_nest_.size(), 0U)
<< "Double buffer scope must be inside a loop";
auto it = dbuffer_info_.find(buffer.get());
if (it == dbuffer_info_.end()) {
LOG(WARNING) << "Skip double buffer scope " << op->node;
return Mutate(op->body);
}
StorageEntry& e = it->second;
e.loop = loop_nest_.back();
Expr zero = make_const(e.loop->loop_var.type(), 0);
Expr one = make_const(e.loop->loop_var.type(), 1);
Expr two = make_const(e.loop->loop_var.type(), 2);
Expr loop_shift = e.loop->loop_var + one;
e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db",
e.loop->loop_var.type());
e.switch_read_var = e.loop->loop_var % two;
in_double_buffer_scope_ = true;
Stmt body = Mutate(op->body);
in_double_buffer_scope_ = false;
std::unordered_map<const Variable*, Expr> vmap;
vmap[e.switch_write_var.get()] = zero;
vmap[e.loop->loop_var.get()] = zero;
loop_pre_[e.loop].emplace_back(Substitute(body, vmap));
vmap[e.loop->loop_var.get()] = loop_shift;
vmap[e.switch_write_var.get()] = loop_shift % two;
body = Substitute(body, vmap);
body = AttrStmt::make(buffer, attr::double_buffer_write, 1, body);
body = IfThenElse::make(loop_shift < e.loop->extent, body);
return body;
}
// Storage entry for those who need double buffering.
struct StorageEntry {
// The size of the buffer
Expr size;
// The loop we need
const For* loop{nullptr};
// The switch variable.
VarExpr switch_write_var;
// The switch variable for reading.
Expr switch_read_var;
// The storage scope.
std::string scope;
};
// Whether split loop
bool split_loop_;
// Whether we are inside double buffer scope.
bool in_double_buffer_scope_{false};
// The current loop next
std::vector<const For*> loop_nest_;
// The allocs to be appended before the loop
std::unordered_map<const For*, std::vector<Stmt> > loop_allocs_;
// The stmt to be appended before the loop
std::unordered_map<const For*, std::vector<Stmt> > loop_pre_;
// The allocation size of the buffer
std::unordered_map<const Variable*, StorageEntry> dbuffer_info_;
};


Stmt InjectDoubleBuffer(Stmt stmt, bool split_loop) {
return DoubleBufferInjector(split_loop).Inject(stmt);
}
} // namespace ir
} // namespace tvm
Loading

0 comments on commit a45d3b0

Please sign in to comment.