Skip to content

Commit

Permalink
[Opt] Fix redundant clone of stmts across offloaded tasks (taichi-dev…
Browse files Browse the repository at this point in the history
  • Loading branch information
ailzhang authored May 6, 2023
1 parent 0599ecc commit f6ebff3
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 22 deletions.
44 changes: 22 additions & 22 deletions taichi/transforms/offload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ bool demotable_axis_load(Stmt *stmt) {
// Stmt involving simple arithmetic of ExternalTensorShapeAlongAxisStmt
// shouldn't be saved in global tmp, just clone them to each shader
// separately.
if (stmt->is<GlobalLoadStmt>())
return false;
int n_op = stmt->num_operands();
if (n_op == 0) {
return stmt->is<ExternalTensorShapeAlongAxisStmt>() ||
Expand Down Expand Up @@ -442,8 +444,8 @@ class PromoteIntermediateToGlobalTmp : public BasicStmtVisitor {

private:
explicit PromoteIntermediateToGlobalTmp(
const StmtToOffsetMap &local_to_global_offset)
: local_to_global_offset_(local_to_global_offset) {
const StmtToOffsetMap *local_to_global_offset)
: local_to_global_offset_(*local_to_global_offset) {
allow_undefined_visitor = true;
invoke_default_visitor = true;
}
Expand All @@ -454,20 +456,20 @@ class PromoteIntermediateToGlobalTmp : public BasicStmtVisitor {
local_to_global_offset_.find(stmt) != local_to_global_offset_.end() &&
stored_to_global_.find(stmt) == stored_to_global_.end()) {
stored_to_global_.insert(stmt);
auto offset = local_to_global_offset_[stmt];
auto offset = local_to_global_offset_.at(stmt);
auto ptr = stmt->insert_after_me(
Stmt::make<GlobalTemporaryStmt>(offset, stmt->ret_type));
ptr->insert_after_me(Stmt::make<GlobalStoreStmt>(ptr, stmt));
}
}

static void run(IRNode *root, const StmtToOffsetMap &local_to_global_offset) {
static void run(IRNode *root, const StmtToOffsetMap *local_to_global_offset) {
PromoteIntermediateToGlobalTmp pass(local_to_global_offset);
root->accept(&pass);
}

private:
StmtToOffsetMap local_to_global_offset_;
const StmtToOffsetMap &local_to_global_offset_;
std::set<Stmt *> stored_to_global_;
};

Expand All @@ -477,11 +479,11 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
private:
FixCrossOffloadReferences(
const CompileConfig &config,
const StmtToOffsetMap &local_to_global_offset,
const std::unordered_map<Stmt *, Stmt *> &stmt_to_offloaded,
const StmtToOffsetMap *local_to_global_offset,
std::unordered_map<Stmt *, Stmt *> &stmt_to_offloaded,
OffloadedRanges *offloaded_ranges)
: config_(config),
local_to_global_offset_(local_to_global_offset),
local_to_global_offset_(*local_to_global_offset),
stmt_to_offloaded_(stmt_to_offloaded),
offloaded_ranges_(offloaded_ranges) {
allow_undefined_visitor = true;
Expand All @@ -499,9 +501,8 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
offloaded_ranges_->begin_stmts.find(stmt)->second) !=
local_to_global_offset_.end(),
"Begin fails.")
stmt->begin_offset =
local_to_global_offset_[offloaded_ranges_->begin_stmts.find(stmt)
->second];
stmt->begin_offset = local_to_global_offset_.at(
offloaded_ranges_->begin_stmts.find(stmt)->second);
}
if (!stmt->const_end) {
if (stmt->end_stmt) {
Expand All @@ -514,9 +515,8 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
offloaded_ranges_->end_stmts.find(stmt)->second) !=
local_to_global_offset_.end(),
"End fails.")
stmt->end_offset =
local_to_global_offset_[offloaded_ranges_->end_stmts.find(stmt)
->second];
stmt->end_offset = local_to_global_offset_.at(
offloaded_ranges_->end_stmts.find(stmt)->second);
}
}
}
Expand All @@ -530,7 +530,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
auto ret_type = stmt->ret_type;
local_to_global_vector_type_[stmt] = ret_type;
auto ptr = replacement.push_back<GlobalTemporaryStmt>(
local_to_global_offset_[stmt], ret_type);
local_to_global_offset_.at(stmt), ret_type);
auto offloaded = stmt_to_offloaded_[stmt];
stmt_to_offloaded_[ptr] = offloaded;
if (auto tensor_type = stmt->ret_type->cast<TensorType>()) {
Expand Down Expand Up @@ -623,7 +623,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
generic_visit(pcopy);
} else {
auto global_temporary = Stmt::make<GlobalTemporaryStmt>(
local_to_global_offset_[op], op->ret_type);
local_to_global_offset_.at(op), op->ret_type);
stmt_to_offloaded_[global_temporary.get()] = offloaded;
stmt->set_operand(index, global_temporary.get());
if (op->is<AllocaStmt>() || op->ret_type.is_pointer()) {
Expand Down Expand Up @@ -660,8 +660,8 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
public:
static void run(IRNode *root,
const CompileConfig &config,
const StmtToOffsetMap &local_to_global_offset,
const std::unordered_map<Stmt *, Stmt *> &stmt_to_offloaded,
const StmtToOffsetMap *local_to_global_offset,
std::unordered_map<Stmt *, Stmt *> &stmt_to_offloaded,
OffloadedRanges *offloaded_ranges) {
FixCrossOffloadReferences pass(config, local_to_global_offset,
stmt_to_offloaded, offloaded_ranges);
Expand All @@ -670,8 +670,8 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {

private:
[[maybe_unused]] const CompileConfig &config_;
StmtToOffsetMap local_to_global_offset_;
std::unordered_map<Stmt *, Stmt *> stmt_to_offloaded_;
const StmtToOffsetMap &local_to_global_offset_;
std::unordered_map<Stmt *, Stmt *> &stmt_to_offloaded_;
OffloadedRanges *const offloaded_ranges_;
std::unordered_map<Stmt *, DataType> local_to_global_vector_type_;
};
Expand Down Expand Up @@ -783,9 +783,9 @@ void offload(IRNode *root, const CompileConfig &config) {
auto stmt_to_offloaded = StmtToOffloaded::run(root);
const auto local_to_global_offset = IdentifyValuesUsedInOtherOffloads::run(
root, config, stmt_to_offloaded, &offloaded_ranges);
PromoteIntermediateToGlobalTmp::run(root, local_to_global_offset);
PromoteIntermediateToGlobalTmp::run(root, &local_to_global_offset);
stmt_to_offloaded = StmtToOffloaded::run(root);
FixCrossOffloadReferences::run(root, config, local_to_global_offset,
FixCrossOffloadReferences::run(root, config, &local_to_global_offset,
stmt_to_offloaded, &offloaded_ranges);
}
insert_gc(root, config);
Expand Down
21 changes: 21 additions & 0 deletions tests/python/test_offload_cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,24 @@ def run(a: ti.i32):
print("OK")

run(2)


@test_utils.test(exclude=ti.amdgpu)
def test_offload_with_save():
a = ti.Vector.field(2, dtype=ti.f32, shape=1)
b = ti.Vector.field(2, dtype=ti.f32, shape=1)
c = ti.Vector.field(2, dtype=ti.f32, shape=1)

@ti.kernel
def test():
a[0] = ti.Vector([1, 1])
b[0] = ti.Vector([0, 0])
c[0] = ti.Vector([0, 0])
b[0] += a[0] # b[0] = [1, 1]
b[0] /= 2 # b[0] = [0.5, 0.5]
for i in c:
c[i] += b[0] # c[0] = [0.5, 0.5]

test()
assert c[0][0] == 0.5
assert c[0][1] == 0.5

0 comments on commit f6ebff3

Please sign in to comment.