[TIR][REFACTOR] Migrate low-level passes in tvm.lower to the Unified IR pass manager. (apache#5364)

- Migrate BoundCheckers and Simplify
- Migrate RewriteUnsafeSelect and RemoveNoOp
- Migrate UnrollLoop and StorageRewrite
- Migrate InjectDoubleBuffer and InjectVirtualThread
- Migrate LoopPartition and Vectorize
- Migrate CoProcSync, LiftAttrScope, InjectCopyIntrin

We still keep ir_pass registrations for now.
Need a separate PR to refactor the parts before the StorageFlatten.
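
The shape of this migration can be illustrated with a small, self-contained sketch. It is not TVM code: `Stmt` here is a plain list of strings and the `Pass` class is a hypothetical stand-in. It only shows the change visible in the headers below, where a free function of the form `Stmt RemoveNoOp(Stmt stmt)` becomes a factory `Pass RemoveNoOp()` that returns an object a pass manager can schedule. TVM's real passes are applied to a module rather than a bare statement; the toy ignores that.

```python
from dataclasses import dataclass
from typing import Callable, List

# Toy stand-in for a TIR statement; hypothetical, not TVM's Stmt.
Stmt = List[str]


def remove_no_op_legacy(stmt: Stmt) -> Stmt:
    """Legacy style: a free function that maps Stmt -> Stmt directly."""
    return [s for s in stmt if s != "nop"]


@dataclass
class Pass:
    """Unified style: a named object that a pass manager can schedule."""
    name: str
    func: Callable[[Stmt], Stmt]

    def __call__(self, stmt: Stmt) -> Stmt:
        return self.func(stmt)


def RemoveNoOp() -> Pass:
    """Factory that returns a pass instead of transforming a statement."""
    return Pass("RemoveNoOp", remove_no_op_legacy)


body = ["a = 1", "nop", "b = a"]
legacy_result = remove_no_op_legacy(body)
pass_result = RemoveNoOp()(body)
```

Both styles compute the same result in this toy; the difference is that the factory form gives every transformation a uniform type and a name, which is what allows them to be collected into one pipeline.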
tqchen authored Apr 18, 2020
1 parent fbcf61a commit 3264895
Showing 37 changed files with 1,026 additions and 458 deletions.
1 change: 0 additions & 1 deletion include/tvm/tir/analysis.h
@@ -53,7 +53,6 @@ struct ExprDeepEqual {
TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
};


/*!
* \brief Find undefined vars in the statement.
* \param stmt The function to be checked.
140 changes: 0 additions & 140 deletions include/tvm/tir/ir_pass.h
@@ -202,144 +202,13 @@ Stmt RewriteForTensorCore(Stmt stmt,
*/
bool VerifyCompactBuffer(Stmt stmt);

/*!
* \brief Remove No Op from the Stmt.
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt RemoveNoOp(Stmt stmt);

/*!
* \brief unroll the constant loop marked by unroll.
* This pass also automatically attaches the pragma unroll tag to loops that meet the standard.
*
* \param stmt The statement to be unrolled.
* \param auto_max_step The maximum step before stop attach automatic unroll
* \param auto_max_depth The maximum depth before stop attach automatic unroll
* \param auto_max_extent The maximum extent of the loop we can unroll,
* this is a legacy option that does not take the loop's total steps into account.
* \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen.
* \return Transformed stmt.
*/
Stmt UnrollLoop(Stmt stmt,
int auto_max_step,
int auto_max_depth,
int auto_max_extent,
bool explicit_unroll);

/*!
* \brief vectorize the constant loops
* \param stmt The statement to be vectorized.
* \return Transformed stmt.
*/
Stmt VectorizeLoop(Stmt stmt);

/*!
* \brief convert vectorized loops into serialized loops
* \param stmt The statement to skip vectorization on.
* \return Transformed stmt.
*/
Stmt SkipVectorize(Stmt stmt);

/*!
* \brief instruments bound checkers.
* \param stmt The statement to be instrumented.
* \return Instrumented stmt.
*/
Stmt InstrumentBoundCheckers(Stmt stmt);

/*!
* \brief Inject virtual thread loops into stmt.
* \param stmt The statement to be transformed.
* \return Transformed stmt.
*/
Stmt InjectVirtualThread(Stmt stmt);

/*!
* \brief Inject prefetch instructions into stmt.
* \param stmt The statement to be transformed.
* \return Transformed stmt.
*/
Stmt InjectPrefetch(Stmt stmt);

/*!
* \brief Inject double buffer into stmt.
* \param stmt The statement to be transformed.
* \param split_loop Loop splitting factor.
* \return Transformed stmt.
*/
Stmt InjectDoubleBuffer(Stmt stmt, int split_loop);

/*!
* \brief Inject copy intrinsics with optional pad.
*
* \param stmt The statement to be transformed.
* \param pragma_key The pragma key for hint of copy.
* \param fintrin The function with signature
*
* Stmt fintrin(Buffer src,
* Buffer dst,
* Array<Expr> pad_before,
* Array<Expr> pad_after,
* Expr pad_value)
* \return Transformed stmt.
*/
Stmt InjectCopyIntrin(Stmt stmt,
const std::string& pragma_key,
const runtime::PackedFunc& fintrin);

/*!
* \brief Rewrite storage allocation pattern.
* Moves the allocation to outer most possible scope.
* Trying to share space between allocations to make
* a static allocation plan when possible.
*
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt StorageRewrite(Stmt stmt);

/*!
* \brief partition loops in the stmt
* \param stmt The stmt to do loop partition
* \param split_const_loop flag to enable partition for const loop
* \return Transformed stmt.
*/
Stmt LoopPartition(Stmt stmt, bool split_const_loop);

/*!
* \brief Detect and insert sync points to co-processor.
*
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt CoProcSync(Stmt stmt);

/*!
* \brief Lift common attrs with attr_key to outer scope.
*
* \param stmt The stmt to be transformed
* \param attr_key The attribute key to be checked.
* \return Transformed stmt.
*/
Stmt LiftAttrScope(Stmt stmt, std::string attr_key);

/*!
* \brief Detect and rewrite unsafe select that contains memory access.
* \param stmt The statement to be rewritten.
* \return Transformed stmt.
*/
Stmt RewriteUnsafeSelect(Stmt stmt);

/*!
* \brief Lower attached storage access information.
* Do this pass after all storage access analysis finish.
*
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt LowerStorageAccessInfo(Stmt stmt);

/*!
* \brief Decorate the stmt with a device scope, this is helpful for
* hardware accelerator without thread blocks.
@@ -356,15 +225,6 @@ Stmt DecorateDeviceScope(Stmt stmt);
*/
Stmt HoistIfThenElse(Stmt stmt);

/*!
* \brief Narrow down PrimExpr datatype in stmt to target_bits.
* \note Run this pass after StorageFlatten.
* \param stmt The stmt to do datatype rewrite
* \param target_bits the bit of target datatype
* \return Transformed stmt.
*/
Stmt NarrowDataType(Stmt stmt, int target_bits);

/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
118 changes: 118 additions & 0 deletions include/tvm/tir/transform.h
@@ -58,6 +58,124 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
const std::string& name,
const tvm::Array<runtime::String>& required);

/*!
* \brief Inject copy intrinsics with optional pad.
*
* \param pragma_key The pragma key for hint of copy.
* \param fintrin The function with signature
*
* Stmt fintrin(Buffer src,
* Buffer dst,
* Array<Expr> pad_before,
* Array<Expr> pad_after,
* Expr pad_value)
* \return The pass.
*/
TVM_DLL Pass InjectCopyIntrin(std::string pragma_key,
runtime::PackedFunc fintrin);

/*!
* \brief Detect and insert sync points to co-processor.
*
* \return The pass.
*/
TVM_DLL Pass CoProcSync();

/*!
* \brief Lift common attrs with attr_key to outer scope.
*
* \param attr_key The attribute key to be checked.
* \return The pass.
*/
TVM_DLL Pass LiftAttrScope(std::string attr_key);

/*!
* \brief partition loops in the stmt.
*
* \param split_const_loop flag to enable partition for const loop
*
* \return The pass.
*/
TVM_DLL Pass LoopPartition(bool split_const_loop);

/*!
* \brief Lower vectorization loops.
*
* \param enable_vectorize Whether vectorization is enabled.
*
* \return The pass.
*/
TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true);

/*!
* \brief Inject virtual thread loops.
*
* \return The pass.
*/
TVM_DLL Pass InjectVirtualThread();

/*!
* \brief Inject double buffer statements.
*
* \param split_loop_factor Loop splitting factor.
* \return The pass.
*/
TVM_DLL Pass InjectDoubleBuffer(int split_loop_factor);

/*!
* \brief Rewrite storage allocation pattern.
* Moves the allocation to outer most possible scope.
* Trying to share space between allocations to make
* a static allocation plan when possible.
*
* \return The pass.
*/
TVM_DLL Pass StorageRewrite();

/*!
* \brief unroll the constant loop marked by unroll.
* This pass also automatically attaches the pragma unroll tag to loops that meet the standard.
*
* \param auto_max_step The maximum step before stop attach automatic unroll
* \param auto_max_depth The maximum depth before stop attach automatic unroll
* \param auto_max_extent The maximum extent of the loop we can unroll,
* this is a legacy option that does not take the loop's total steps into account.
* \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen.
* \return The pass.
*/
TVM_DLL Pass UnrollLoop(int auto_max_step,
int auto_max_depth,
int auto_max_extent,
bool explicit_unroll);

/*!
* \brief Remove No Op from the Stmt.
*
* \return The pass.
*/
TVM_DLL Pass RemoveNoOp();

/*!
* \brief Detect and rewrite unsafe select that contains memory access.
*
* \return The pass.
*/
TVM_DLL Pass RewriteUnsafeSelect();

/*!
* \brief Run arithmetic simplifications on the statements and expressions.
*
* \return The pass.
*/
TVM_DLL Pass Simplify();

/*!
* \brief Instruments bound checkers.
*
* \return The pass.
*/
TVM_DLL Pass InstrumentBoundCheckers();

/*!
* \brief Transform the high-level PrimFunc to a low-level version
* that can be used as an API function.
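
One pattern in the declarations above is that arguments which the legacy functions took alongside `stmt` now become arguments of the pass factory, so a pass's behavior is fixed when it is constructed. The new `VectorizeLoop(bool enable_vectorize = true)` also appears to cover what the removed header expressed as two functions, `VectorizeLoop` and `SkipVectorize`; that reading is an inference from the signatures, not something the diff states. A minimal sketch under those assumptions, using a hypothetical dictionary-based loop record instead of TVM nodes:

```python
from typing import Callable, Dict, List

# Toy loop record; hypothetical, not a TVM node.
Loop = Dict[str, str]
Pass = Callable[[List[Loop]], List[Loop]]


def VectorizeLoop(enable_vectorize: bool = True) -> Pass:
    """Return a pass whose behavior is fixed at construction time."""

    def run(loops: List[Loop]) -> List[Loop]:
        out = []
        for loop in loops:
            if loop["kind"] != "vectorized":
                # Loops not marked for vectorization are copied unchanged.
                out.append(dict(loop))
            elif enable_vectorize:
                # Assumed behavior: lower the marked loop to a vector op.
                out.append({**loop, "kind": "vector_op"})
            else:
                # Assumed behavior: fall back to an ordinary serial loop.
                out.append({**loop, "kind": "serial"})
        return out

    return run


loops = [{"var": "i", "kind": "vectorized"}, {"var": "j", "kind": "serial"}]
vectorized = VectorizeLoop()(loops)
skipped = VectorizeLoop(enable_vectorize=False)(loops)
```

Capturing the flag in a closure keeps the call site uniform: every pass is invoked the same way, regardless of how many options it was built with.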
4 changes: 4 additions & 0 deletions python/tvm/driver/build_module.py
@@ -179,6 +179,7 @@ def lower(sch,
cfg.auto_unroll_max_depth,
cfg.auto_unroll_max_extent,
cfg.unroll_explicit)

    for f in lower_phase2:
        stmt = f(stmt)

@@ -187,11 +188,14 @@ def lower(sch,
    stmt = ir_pass.RemoveNoOp(stmt)
    if not cfg.disable_select_rewriting:
        stmt = ir_pass.RewriteUnsafeSelect(stmt)

    for f in lower_phase3:
        stmt = f(stmt)

    # Instrument BoundCheckers
    if cfg.instrument_bound_checkers:
        stmt = ir_pass.InstrumentBoundCheckers(stmt)

    if simple_mode:
        return stmt
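
The visible hunks of `lower` still call the legacy `ir_pass` functions one after another. As a rough illustration of where the migration points, the same conditional ordering could be expressed as a list of passes that is built first and run afterwards. Everything below is hypothetical: `named`, `build_phase3`, and `run_sequential` are illustrative helpers, the passes only record their names, and the `cfg` object mimics just the two flags that appear in the hunk.

```python
from types import SimpleNamespace
from typing import Callable, List

Stmt = List[str]  # toy statement; each pass appends its name as a trace
Pass = Callable[[Stmt], Stmt]


def named(name: str) -> Pass:
    """Build a placeholder pass that records that it ran."""
    def run(stmt: Stmt) -> Stmt:
        return stmt + [name]
    return run


def build_phase3(cfg, lower_phase3: List[Pass]) -> List[Pass]:
    """Mirror the conditional ordering shown in the hunk above."""
    passes = [named("RemoveNoOp")]
    if not cfg.disable_select_rewriting:
        passes.append(named("RewriteUnsafeSelect"))
    passes.extend(lower_phase3)  # user-supplied passes
    if cfg.instrument_bound_checkers:
        passes.append(named("InstrumentBoundCheckers"))
    return passes


def run_sequential(passes: List[Pass], stmt: Stmt) -> Stmt:
    """Apply each pass in order, threading the result through."""
    for p in passes:
        stmt = p(stmt)
    return stmt


cfg = SimpleNamespace(disable_select_rewriting=False,
                      instrument_bound_checkers=True)
trace = run_sequential(build_phase3(cfg, [named("custom")]), [])
```

Separating construction of the list from its execution makes the pipeline inspectable before anything runs, which is one motivation commonly given for a pass-manager design.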
