Skip to content

Commit

Permalink
[ir] Remove unnecessary field_dims_ in ArgLoadStmt (taichi-dev#7755)
Browse files Browse the repository at this point in the history
Issue: #

### Brief Summary

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 3b1f7b3</samp>

This pull request refactors the `ArgLoadStmt` class and the related
passes to simplify the handling of external tensors. It removes the
redundant fields `external_dims` and `field_dims_` from the
`ArgLoadStmt` class, and updates the `lower_matrix_ptr` and `type_check`
passes accordingly.

### Walkthrough

<!--
copilot:walkthrough
-->
### <samp>🤖 Generated by Copilot at 3b1f7b3</samp>

* Remove the external_dims field and the field_dims_ field from the
ArgLoadStmt class to simplify the handling of external tensors
([link](https://github.com/taichi-dev/taichi/pull/7755/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L591-L593),
[link](https://github.com/taichi-dev/taichi/pull/7755/files?diff=unified&w=0#diff-917d9436dcaafa0f1e41ae9bad90273a303f036f00da94e417788a7fa1dc5260L183-R183),
[link](https://github.com/taichi-dev/taichi/pull/7755/files?diff=unified&w=0#diff-917d9436dcaafa0f1e41ae9bad90273a303f036f00da94e417788a7fa1dc5260L190-L192),
[link](https://github.com/taichi-dev/taichi/pull/7755/files?diff=unified&w=0#diff-917d9436dcaafa0f1e41ae9bad90273a303f036f00da94e417788a7fa1dc5260L201-R200))
* Update the ret_type of the base_ptr of the ExternalPtrStmt in the
lower_matrix_ptr pass to match the ret_type of the ExternalPtrStmt with
flattened indices
([link](https://github.com/taichi-dev/taichi/pull/7755/files?diff=unified&w=0#diff-9b36b48490841b4018aca81632ae1beac3b2fdf1ee95a5c65eb42b676654b82eL69-R72))
* Simplify the logic of the type check pass for the ExternalPtrStmt by
using the ret_type of the base_ptr as the default ret_type of the
ExternalPtrStmt
([link](https://github.com/taichi-dev/taichi/pull/7755/files?diff=unified&w=0#diff-dd572dab7be4dbb5edc1043d6d6339b931ef35198b8657761ebf45a83e76ac2bL453-R459))
  • Loading branch information
ailzhang authored Apr 10, 2023
1 parent 859e375 commit ea8002e
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 25 deletions.
3 changes: 0 additions & 3 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -588,9 +588,6 @@ void ExternalTensorExpression::flatten(FlattenContext *ctx) {
auto ptr = Stmt::make<ArgLoadStmt>(arg_id, prim_dt, /*is_ptr=*/true,
/*is_grad=*/is_grad);

int external_dims = dim - std::abs(element_dim);
ptr->cast<ArgLoadStmt>()->set_extern_dims(external_dims);

ptr->tb = tb;
ctx->push_back(std::move(ptr));
stmt = ctx->back_stmt();
Expand Down
10 changes: 1 addition & 9 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,16 +180,13 @@ class ArgLoadStmt : public Stmt {
ndarray, ...
Therefore we need to add a field to indicate the type of the argument. For
now, only "is_ptr" and "field_dims" is needed.
now, only "is_ptr" is needed.
*/
bool is_ptr;

bool is_grad;

// field_dims of ndarray
int field_dims_ = 0;

ArgLoadStmt(int arg_id,
const DataType &dt,
bool is_ptr = false,
Expand All @@ -198,14 +195,9 @@ class ArgLoadStmt : public Stmt {
this->ret_type = dt;
this->is_ptr = is_ptr;
this->is_grad = is_grad;
this->field_dims_ = -1; // -1 means uninitialized
TI_STMT_REG_FIELDS;
}

void set_extern_dims(int dims) {
this->field_dims_ = dims;
}

bool has_global_side_effect() const override {
return false;
}
Expand Down
5 changes: 4 additions & 1 deletion taichi/transforms/lower_matrix_ptr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ class LowerMatrixPtr : public BasicStmtVisitor {
auto fused = std::make_unique<ExternalPtrStmt>(
origin->base_ptr, indices, element_shape, element_dim);
fused->ret_type = stmt->ret_type;

// Note: Update base_ptr's ret_type so that it matches the ExternalPtrStmt
// with flattened indices. Main goal is to keep all the hacks in a single
// place so that they're easier to remove
origin->base_ptr->as<ArgLoadStmt>()->ret_type = stmt->ret_type;
stmt->replace_usages_with(fused.get());
modifier_.insert_before(stmt, std::move(fused));
modifier_.erase(stmt);
Expand Down
13 changes: 1 addition & 12 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,24 +450,13 @@ class TypeCheck : public IRVisitor {
}

void visit(ExternalPtrStmt *stmt) override {
/* ExternalPtrStmt may have two different semantics:
1. outer indexing to an argloaded external tensor
2. outer indexing + inner indexing to get the innermost primitive
element of an external tensor
We rely on "external_dims" and "indices" to distinguish these two cases.
Case #1: external_dims == indices.size(), return TensorType
Case #2: external_dims < indices.size(), return PrimitiveType
*/
TI_ASSERT(stmt->base_ptr->is<ArgLoadStmt>());
auto arg_load_stmt = stmt->base_ptr->cast<ArgLoadStmt>();

int external_dims = arg_load_stmt->field_dims_;
if (stmt->overrided_dtype) {
// pass
} else if (external_dims == stmt->indices.size() || external_dims == -1) {
stmt->ret_type = arg_load_stmt->ret_type;
} else {
stmt->ret_type = arg_load_stmt->ret_type.ptr_removed().get_element_type();
stmt->ret_type = arg_load_stmt->ret_type;
}

stmt->ret_type.set_is_pointer(true);
Expand Down

0 comments on commit ea8002e

Please sign in to comment.