Skip to content

Commit

Permalink
[async] [lang] [opt] Add ti.loop_unique(covers=...) to improve task d…
Browse files Browse the repository at this point in the history
…ependence analysis (taichi-dev#2163)

* Add loop_unique(covers=[snodes...])

* Make use of covers_snode in get_task_meta

* Improve DSE debug messages

* Fix covers frontend

* Improve same_statements and whole_kernel_cse

* [skip ci] enforce code format

* retrigger CI

* [skip ci] Apply suggestions from code review

Co-authored-by: Yuanming Hu <[email protected]>

Co-authored-by: Taichi Gardener <[email protected]>
Co-authored-by: Yuanming Hu <[email protected]>
  • Loading branch information
3 people authored Jan 19, 2021
1 parent 6c1c3c8 commit 996d7b1
Show file tree
Hide file tree
Showing 15 changed files with 188 additions and 53 deletions.
9 changes: 7 additions & 2 deletions python/taichi/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,13 @@ def assume_in_range(val, base, low, high):
Expr(base).ptr, low, high)


def loop_unique(val):
return taichi_lang_core.expr_loop_unique(Expr(val).ptr)
def loop_unique(val, covers=None):
if covers is None:
covers = []
if not isinstance(covers, (list, tuple)):
covers = [covers]
covers = [x.snode.ptr if isinstance(x, Expr) else x.ptr for x in covers]
return taichi_lang_core.expr_loop_unique(Expr(val).ptr, covers)


parallelize = core.parallelize
Expand Down
8 changes: 4 additions & 4 deletions taichi/analysis/gather_uniquely_accessed_pointers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor {
// Search SNodes that are uniquely accessed, i.e., accessed by
// one GlobalPtrStmt (or by definitely-same-address GlobalPtrStmts),
// and that GlobalPtrStmt's address is loop-unique.
std::unordered_map<SNode *, GlobalPtrStmt *> accessed_pointer_;
std::unordered_map<const SNode *, GlobalPtrStmt *> accessed_pointer_;

public:
using BasicStmtVisitor::visit;
Expand Down Expand Up @@ -144,7 +144,7 @@ class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor {
}
}

static std::unordered_map<SNode *, GlobalPtrStmt *> run(IRNode *root) {
static std::unordered_map<const SNode *, GlobalPtrStmt *> run(IRNode *root) {
TI_ASSERT(root->is<OffloadedStmt>());
auto offload = root->as<OffloadedStmt>();
UniquelyAccessedSNodeSearcher searcher;
Expand All @@ -164,8 +164,8 @@ class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor {
};

namespace irpass::analysis {
std::unordered_map<SNode *, GlobalPtrStmt *> gather_uniquely_accessed_pointers(
IRNode *root) {
std::unordered_map<const SNode *, GlobalPtrStmt *>
gather_uniquely_accessed_pointers(IRNode *root) {
// TODO: What about SNodeOpStmts?
return UniquelyAccessedSNodeSearcher::run(root);
}
Expand Down
67 changes: 46 additions & 21 deletions taichi/analysis/same_statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,40 +168,65 @@ class IRNodeComparator : public IRVisitor {
}
}

if (check_same_value_ && stmt->is<GlobalPtrStmt>()) {
// Special case: we do not care the "activate" field when checking
// whether two global pointers share the same value.
// And we cannot use irpass::analysis::definitely_same_address()
// directly because that function does not support id_map.
bool field_checked = false;
if (check_same_value_) {
if (stmt->is<GlobalPtrStmt>()) {
// Special case: we do not care about the "activate" field when checking
// whether two global pointers share the same value.
// And we cannot use irpass::analysis::definitely_same_address()
// directly because that function does not support id_map.

// TODO: Update this part if GlobalPtrStmt comes to have more fields
TI_ASSERT(stmt->width() == 1);
if (stmt->as<GlobalPtrStmt>()->snodes[0]->id !=
other->as<GlobalPtrStmt>()->snodes[0]->id) {
same = false;
return;
// TODO: Update this part if GlobalPtrStmt comes to have more fields
TI_ASSERT(stmt->width() == 1);
if (stmt->as<GlobalPtrStmt>()->snodes[0]->id !=
other->as<GlobalPtrStmt>()->snodes[0]->id) {
same = false;
return;
}
field_checked = true;
} else if (stmt->is<LoopUniqueStmt>()) {
// Special case: we do not care about the "covers" field when checking
// whether two LoopUniqueStmts share the same value.
field_checked = true;
} else if (stmt->is<RangeAssumptionStmt>()) {
// Special case: we do not care about the "low, high" fields when checking
// whether two RangeAssumptionStmts share the same value.
field_checked = true;
}
} else {
}
if (!field_checked) {
// field check
if (!stmt->field_manager.equal(other->field_manager)) {
same = false;
return;
}
}

// operand check
if (stmt->num_operands() != other->num_operands()) {
same = false;
return;
bool operand_checked = false;
if (check_same_value_) {
if (stmt->is<RangeAssumptionStmt>()) {
// Special case: we do not care about the "base" operand when checking
// whether two RangeAssumptionStmts share the same value.
check_mapping(stmt->as<RangeAssumptionStmt>()->input,
other->as<RangeAssumptionStmt>()->input);
operand_checked = true;
}
}
for (int i = 0; i < stmt->num_operands(); i++) {
if ((stmt->operand(i) == nullptr) != (other->operand(i) == nullptr)) {
if (!operand_checked) {
// operand check
if (stmt->num_operands() != other->num_operands()) {
same = false;
return;
}
if (stmt->operand(i) == nullptr)
continue;
check_mapping(stmt->operand(i), other->operand(i));
for (int i = 0; i < stmt->num_operands(); i++) {
if ((stmt->operand(i) == nullptr) != (other->operand(i) == nullptr)) {
same = false;
return;
}
if (stmt->operand(i) == nullptr)
continue;
check_mapping(stmt->operand(i), other->operand(i));
}
}

map_id(stmt->id, other->id);
Expand Down
4 changes: 2 additions & 2 deletions taichi/ir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ std::pair<std::unordered_set<SNode *>, std::unordered_set<SNode *>>
gather_snode_read_writes(IRNode *root);
std::vector<Stmt *> gather_statements(IRNode *root,
const std::function<bool(Stmt *)> &test);
std::unordered_map<SNode *, GlobalPtrStmt *> gather_uniquely_accessed_pointers(
IRNode *root);
std::unordered_map<const SNode *, GlobalPtrStmt *>
gather_uniquely_accessed_pointers(IRNode *root);
std::unique_ptr<std::unordered_set<AtomicOpStmt *>> gather_used_atomics(
IRNode *root);
std::vector<Stmt *> get_load_pointers(Stmt *load_stmt);
Expand Down
4 changes: 2 additions & 2 deletions taichi/ir/frontend.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ inline Expr AssumeInRange(const Expr &expr,
return Expr::make<RangeAssumptionExpression>(expr, base, low, high);
}

inline Expr LoopUnique(const Expr &input) {
return Expr::make<LoopUniqueExpression>(load_if_ptr(input));
inline Expr LoopUnique(const Expr &input, const std::vector<SNode *> &covers) {
return Expr::make<LoopUniqueExpression>(load_if_ptr(input), covers);
}

void insert_snode_access_flag(SNodeAccessFlag v, const Expr &field);
Expand Down
17 changes: 16 additions & 1 deletion taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,24 @@ void RangeAssumptionExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

std::string LoopUniqueExpression::serialize() {
std::string result = "loop_unique(" + input->serialize();
for (int i = 0; i < covers.size(); i++) {
if (i == 0)
result += ", covers=[";
result += covers[i]->get_node_type_name_hinted();
if (i == (int)covers.size() - 1)
result += "]";
else
result += ", ";
}
result += ")";
return result;
}

void LoopUniqueExpression::flatten(FlattenContext *ctx) {
input->flatten(ctx);
ctx->push_back(Stmt::make<LoopUniqueStmt>(input->stmt));
ctx->push_back(Stmt::make<LoopUniqueStmt>(input->stmt, covers));
stmt = ctx->back_stmt();
}

Expand Down
8 changes: 4 additions & 4 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -453,13 +453,13 @@ class RangeAssumptionExpression : public Expression {
class LoopUniqueExpression : public Expression {
public:
Expr input;
std::vector<SNode *> covers;

LoopUniqueExpression(const Expr &input) : input(input) {
LoopUniqueExpression(const Expr &input, const std::vector<SNode *> &covers)
: input(input), covers(covers) {
}

std::string serialize() override {
return fmt::format("loop_unique({})", input.serialize());
}
std::string serialize() override;

void flatten(FlattenContext *ctx) override;
};
Expand Down
38 changes: 37 additions & 1 deletion taichi/ir/statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ GlobalPtrStmt::GlobalPtrStmt(const LaneAttribute<SNode *> &snodes,
TI_STMT_REG_FIELDS;
}

bool GlobalPtrStmt::is_element_wise(SNode *snode) const {
bool GlobalPtrStmt::is_element_wise(const SNode *snode) const {
if (snode == nullptr) {
// check every SNode when "snode" is nullptr
for (const auto &snode_i : snodes.data) {
Expand All @@ -88,6 +88,18 @@ bool GlobalPtrStmt::is_element_wise(SNode *snode) const {
return true;
}

bool GlobalPtrStmt::covers_snode(const SNode *snode) const {
// Check if the addresses of this statement all over the loop cover
// all active indices of the snode.
for (auto &index : indices) {
if (auto loop_unique = index->cast<LoopUniqueStmt>()) {
if (loop_unique->covers_snode(snode))
return true;
}
}
return is_element_wise(snode);
}

SNodeOpStmt::SNodeOpStmt(SNodeOpType op_type,
SNode *snode,
Stmt *ptr,
Expand All @@ -112,6 +124,30 @@ ExternalTensorShapeAlongAxisStmt::ExternalTensorShapeAlongAxisStmt(int axis,
TI_STMT_REG_FIELDS;
}

// Builds a LoopUniqueStmt over `input` and records the ids of the SNodes in
// `covers`. A place SNode is recorded under its parent's id, after logging an
// info message; every other SNode is recorded under its own id.
LoopUniqueStmt::LoopUniqueStmt(Stmt *input, const std::vector<SNode *> &covers)
    : input(input) {
  for (SNode *node : covers) {
    if (!node->is_place()) {
      this->covers.insert(node->id);
      continue;
    }
    TI_INFO(
        "A place SNode {} appears in the 'covers' parameter "
        "of 'ti.loop_unique'. It is recommended to use its parent "
        "(x.parent()) instead.",
        node->get_node_type_name_hinted());
    this->covers.insert(node->parent->id);
  }
  TI_STMT_REG_FIELDS;
}

// Reports whether the id of `snode`'s parent is present in `covers`.
// Only place SNodes are handled; any other SNode reaches TI_NOT_IMPLEMENTED.
bool LoopUniqueStmt::covers_snode(const SNode *snode) const {
  if (!snode->is_place()) {
    TI_NOT_IMPLEMENTED
  } else {
    const auto parent_id = snode->parent->id;
    return covers.find(parent_id) != covers.end();
  }
}

Stmt *LocalLoadStmt::previous_store_or_alloca_in_block() {
int position = parent->locate(this);
// TI_ASSERT(width() == 1);
Expand Down
15 changes: 10 additions & 5 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ class GlobalPtrStmt : public Stmt {
const std::vector<Stmt *> &indices,
bool activate = true);

bool is_element_wise(SNode *snode) const;
bool is_element_wise(const SNode *snode) const;

bool covers_snode(const SNode *snode) const;

bool has_global_side_effect() const override {
return activate;
Expand Down Expand Up @@ -334,16 +336,19 @@ class RangeAssumptionStmt : public Stmt {
class LoopUniqueStmt : public Stmt {
public:
Stmt *input;
std::unordered_set<int> covers; // Stores SNode id
// std::unordered_set<> provides operator==, and StmtFieldManager will
// use that to check if two LoopUniqueStmts are the same.

explicit LoopUniqueStmt(Stmt *input) : input(input) {
TI_STMT_REG_FIELDS;
}
LoopUniqueStmt(Stmt *input, const std::vector<SNode *> &covers);

bool covers_snode(const SNode *snode) const;

bool has_global_side_effect() const override {
return false;
}

TI_STMT_DEF_FIELDS(ret_type, input);
TI_STMT_DEF_FIELDS(ret_type, input, covers);
TI_DEFINE_ACCEPT_AND_CLONE
};

Expand Down
29 changes: 24 additions & 5 deletions taichi/program/async_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,14 +331,33 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {
insert_value_states_top_down(root_stmt->snode);
}

// We are being conservative here: if there are any non-element-wise
// accesses (e.g., a = x[i + 1]), we don't treat it as completely
// overwriting the value state (e.g., for i in x: x[i] = 0).
for (auto &state : meta.output_states) {
// We need to insert input value states in case of partial writes.
// Assume we write sn on every index we access in this task,
// because we would have inserted the input value state in
// get_meta_input_value_states otherwise.
if (state.type == AsyncState::Type::value && state.holds_snode()) {
const auto *sn = state.snode();
if (meta.element_wise.find(sn) == meta.element_wise.end() ||
!meta.element_wise[sn]) {
bool completely_overwriting = false;
if (meta.element_wise[sn]) {
// If every access on sn is element-wise, then it must be
// completely overwriting.
completely_overwriting = true;
// TODO: this is also completely overwriting although element_wise[sn]
// is false:
// for i in x:
// x[i] = 0
// x[i + 1] = 0
// A solution to this is to gather all definite writes in the task,
// and check if any one of them ->covers_snode(sn).
// TODO: is element-wise useless since it must be loop-unique?
}
if (meta.loop_unique.count(sn) > 0 && meta.loop_unique[sn] != nullptr) {
if (meta.loop_unique[sn]->covers_snode(sn)) {
completely_overwriting = true;
}
}
if (!completely_overwriting) {
meta.input_states.insert(state);
}
}
Expand Down
6 changes: 5 additions & 1 deletion taichi/program/async_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,13 @@ struct TaskMeta {
SNode *snode{nullptr}; // struct-for and listgen only
std::unordered_set<AsyncState> input_states;
std::unordered_set<AsyncState> output_states;
std::unordered_map<SNode *, GlobalPtrStmt *> loop_unique;

// loop_unique[s] != nullptr => injective access on s
std::unordered_map<const SNode *, GlobalPtrStmt *> loop_unique;
std::unordered_map<const SNode *, bool> element_wise;

// element_wise[s] OR loop_unique[s] covers s => surjective access on s

void print() const;
};

Expand Down
9 changes: 7 additions & 2 deletions taichi/program/ir_bank.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,6 @@ std::pair<IRHandle, bool> IRBank::optimize_dse(
if (verbose) {
TI_INFO(" DSE: after CFG, modified={}", modified);
std::cout << std::flush;
irpass::print(new_ir.get());
std::cout << std::flush;
}

if (!modified) {
Expand All @@ -225,6 +223,13 @@ std::pair<IRHandle, bool> IRBank::optimize_dse(
irpass::flag_access(new_ir.get());
irpass::die(new_ir.get());

if (verbose) {
TI_INFO(" DSE: after flag_access and DIE");
std::cout << std::flush;
irpass::print(new_ir.get());
std::cout << std::flush;
}

ret_handle = IRHandle(new_ir.get(), get_hash(new_ir.get()));
insert(std::move(new_ir), ret_handle.hash());
return std::make_pair(ret_handle, false);
Expand Down
2 changes: 1 addition & 1 deletion taichi/transforms/demote_atomics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ TLANG_NAMESPACE_BEGIN

class DemoteAtomics : public BasicStmtVisitor {
private:
std::unordered_map<SNode *, GlobalPtrStmt *> loop_unique_ptr_;
std::unordered_map<const SNode *, GlobalPtrStmt *> loop_unique_ptr_;

public:
using BasicStmtVisitor::visit;
Expand Down
13 changes: 11 additions & 2 deletions taichi/transforms/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,8 +409,17 @@ class IRPrinter : public IRVisitor {
}

void visit(LoopUniqueStmt *stmt) override {
print("{}{} = loop_unique({})", stmt->type_hint(), stmt->name(),
stmt->input->name());
std::string add = "";
if (!stmt->covers.empty()) {
add = ", covers=[";
for (const auto &sn : stmt->covers) {
add += fmt::format("S{}, ", sn);
}
add.erase(add.size() - 2, 2); // remove the last ", "
add += "]";
}
print("{}{} = loop_unique({}{})", stmt->type_hint(), stmt->name(),
stmt->input->name(), add);
}

void visit(LinearizeStmt *stmt) override {
Expand Down
Loading

0 comments on commit 996d7b1

Please sign in to comment.