Skip to content

Commit

Permalink
[ir] [refactor] Remove ptr_if_global in C++ Expr class (taichi-dev#3285)
Browse files Browse the repository at this point in the history
* Remove ptr_if_global

* Auto Format

Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
strongoier and taichi-gardener authored Oct 26, 2021
1 parent 059e5ff commit 1fe999f
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 43 deletions.
25 changes: 6 additions & 19 deletions taichi/ir/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ Expr &Expr::operator=(const Expr &o) {
if (expr == nullptr) {
set(o.eval());
} else if (expr->is_lvalue()) {
current_ast_builder().insert(std::make_unique<FrontendAssignStmt>(
ptr_if_global(*this), load_if_ptr(o)));
current_ast_builder().insert(
std::make_unique<FrontendAssignStmt>(*this, load_if_ptr(o)));
} else {
// set(o.eval());
TI_ERROR("Cannot assign to non-lvalue: {}", serialize());
Expand Down Expand Up @@ -140,17 +140,17 @@ Expr Expr::eval() const {

void Expr::operator+=(const Expr &o) {
if (this->atomic) {
(*this) = Expr::make<AtomicOpExpression>(
AtomicOpType::add, ptr_if_global(*this), load_if_ptr(o));
(*this) = Expr::make<AtomicOpExpression>(AtomicOpType::add, *this,
load_if_ptr(o));
} else {
(*this) = (*this) + o;
}
}

void Expr::operator-=(const Expr &o) {
if (this->atomic) {
(*this) = Expr::make<AtomicOpExpression>(
AtomicOpType::sub, ptr_if_global(*this), load_if_ptr(o));
(*this) = Expr::make<AtomicOpExpression>(AtomicOpType::sub, *this,
load_if_ptr(o));
} else {
(*this) = (*this) - o;
}
Expand Down Expand Up @@ -186,19 +186,6 @@ Expr load_if_ptr(const Expr &ptr) {
return ptr;
}

Expr ptr_if_global(const Expr &var) {
if (var.is<GlobalVariableExpression>()) {
// singleton global variable
TI_ASSERT_INFO(var.snode()->num_active_indices == 0,
"Please always use 'x[None]' (instead of simply 'x') to "
"access any 0-D field.");
return var[ExprGroup()];
} else {
// may be any local or global expr
return var;
}
}

Expr Var(const Expr &x) {
auto var = Expr(std::make_shared<IdExpression>());
current_ast_builder().insert(std::make_unique<FrontendAllocaStmt>(
Expand Down
5 changes: 0 additions & 5 deletions taichi/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,6 @@ Expr bit_cast(const Expr &input) {
}

Expr load_if_ptr(const Expr &ptr);
Expr ptr_if_global(const Expr &var);

inline Expr smart_load(const Expr &var) {
return load_if_ptr(ptr_if_global(var));
}

// Begin: legacy frontend functions
Expr Var(const Expr &x);
Expand Down
6 changes: 3 additions & 3 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ class UnaryOpExpression : public Expression {
DataType cast_type;

UnaryOpExpression(UnaryOpType type, const Expr &operand)
: type(type), operand(smart_load(operand)) {
: type(type), operand(load_if_ptr(operand)) {
cast_type = PrimitiveType::unknown;
}

Expand All @@ -288,8 +288,8 @@ class BinaryOpExpression : public Expression {

BinaryOpExpression(const BinaryOpType &type, const Expr &lhs, const Expr &rhs)
: type(type) {
this->lhs.set(smart_load(lhs));
this->rhs.set(smart_load(rhs));
this->lhs.set(load_if_ptr(lhs));
this->rhs.set(load_if_ptr(rhs));
}

void serialize(std::ostream &ss) override {
Expand Down
27 changes: 11 additions & 16 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ Expr expr_index(const Expr &expr, const Expr &index) {
return expr[index];
}

void expr_assign(const Expr &lhs_, const Expr &rhs, std::string tb) {
auto lhs = ptr_if_global(lhs_);
void expr_assign(const Expr &lhs, const Expr &rhs, std::string tb) {
TI_ASSERT(lhs->is_lvalue());
auto stmt = std::make_unique<FrontendAssignStmt>(lhs, load_if_ptr(rhs));
stmt->set_tb(tb);
Expand Down Expand Up @@ -609,38 +608,34 @@ void export_lang(py::module &m) {
static_cast<Expr (*)(const Expr &expr, DataType)>(bit_cast));

m.def("expr_atomic_add", [&](const Expr &a, const Expr &b) {
return Expr::make<AtomicOpExpression>(AtomicOpType::add, ptr_if_global(a),
load_if_ptr(b));
return Expr::make<AtomicOpExpression>(AtomicOpType::add, a, load_if_ptr(b));
});

m.def("expr_atomic_sub", [&](const Expr &a, const Expr &b) {
return Expr::make<AtomicOpExpression>(AtomicOpType::sub, ptr_if_global(a),
load_if_ptr(b));
return Expr::make<AtomicOpExpression>(AtomicOpType::sub, a, load_if_ptr(b));
});

m.def("expr_atomic_min", [&](const Expr &a, const Expr &b) {
return Expr::make<AtomicOpExpression>(AtomicOpType::min, ptr_if_global(a),
load_if_ptr(b));
return Expr::make<AtomicOpExpression>(AtomicOpType::min, a, load_if_ptr(b));
});

m.def("expr_atomic_max", [&](const Expr &a, const Expr &b) {
return Expr::make<AtomicOpExpression>(AtomicOpType::max, ptr_if_global(a),
load_if_ptr(b));
return Expr::make<AtomicOpExpression>(AtomicOpType::max, a, load_if_ptr(b));
});

m.def("expr_atomic_bit_and", [&](const Expr &a, const Expr &b) {
return Expr::make<AtomicOpExpression>(AtomicOpType::bit_and,
ptr_if_global(a), load_if_ptr(b));
return Expr::make<AtomicOpExpression>(AtomicOpType::bit_and, a,
load_if_ptr(b));
});

m.def("expr_atomic_bit_or", [&](const Expr &a, const Expr &b) {
return Expr::make<AtomicOpExpression>(AtomicOpType::bit_or,
ptr_if_global(a), load_if_ptr(b));
return Expr::make<AtomicOpExpression>(AtomicOpType::bit_or, a,
load_if_ptr(b));
});

m.def("expr_atomic_bit_xor", [&](const Expr &a, const Expr &b) {
return Expr::make<AtomicOpExpression>(AtomicOpType::bit_xor,
ptr_if_global(a), load_if_ptr(b));
return Expr::make<AtomicOpExpression>(AtomicOpType::bit_xor, a,
load_if_ptr(b));
});

m.def("expr_add", expr_add);
Expand Down

0 comments on commit 1fe999f

Please sign in to comment.