From 5385396bdddb441b65e506f3490dcb5ae8175d6e Mon Sep 17 00:00:00 2001 From: KnowingNothing Date: Mon, 11 May 2020 06:30:59 +0800 Subject: [PATCH] update --- test/ir_mutator.cc | 78 +++++++++++++++++++ ...54\344\270\200\351\203\250\345\210\206.md" | 4 +- ...54\344\272\214\351\203\250\345\210\206.md" | 44 ++++++++++- 3 files changed, 121 insertions(+), 5 deletions(-) create mode 100644 test/ir_mutator.cc diff --git a/test/ir_mutator.cc b/test/ir_mutator.cc new file mode 100644 index 0000000..776ef6a --- /dev/null +++ b/test/ir_mutator.cc @@ -0,0 +1,78 @@ +#include +#include + +#include "IR.h" +#include "IRMutator.h" +#include "IRVisitor.h" +#include "IRPrinter.h" +#include "type.h" + +using namespace Boost::Internal; + + +class MyMutator : public IRMutator { + public: + Expr visit(Ref op) override { + if (op->name == "A") { + return Var::make(op->type(), "modified_A", op->args, op->shape); + } + return IRMutator::visit(op); + } +}; + + +int main() { + const int M = 1024; + const int N = 512; + const int K = 256; + Type index_type = Type::int_scalar(32); + Type data_type = Type::float_scalar(32); + + // index i + Expr dom_i = Dom::make(index_type, 0, M); + Expr i = Index::make(index_type, "i", dom_i, IndexType::Spatial); + + // index j + Expr dom_j = Dom::make(index_type, 0, N); + Expr j = Index::make(index_type, "j", dom_j, IndexType::Spatial); + + // index k + Expr dom_k = Dom::make(index_type, 0, K); + Expr k = Index::make(index_type, "k", dom_k, IndexType::Reduce); + + // A + Expr expr_A = Var::make(data_type, "A", {i, k}, {M, K}); + + // B + Expr expr_B = Var::make(data_type, "B", {k, j}, {K, N}); + + // C + Expr expr_C = Var::make(data_type, "C", {i, j}, {M, N}); + + // main stmt + Stmt main_stmt = Move::make( + expr_C, + Binary::make(data_type, BinaryOpType::Add, expr_C, + Binary::make(data_type, BinaryOpType::Mul, expr_A, expr_B)), + MoveType::MemToMem + ); + + // loop nest + Stmt loop_nest = LoopNest::make({i, j, k}, {main_stmt}); + + // kernel + Group kernel = Kernel::make("simple_gemm", {expr_A, expr_B}, {expr_C}, {loop_nest}, KernelType::CPU); + + // mutator + MyMutator mutator; + kernel = mutator.mutate(kernel); + + // printer + IRPrinter printer; + std::string code = printer.print(kernel); + + std::cout << code; + + std::cout << "Success!\n"; + return 0; +} \ No newline at end of file diff --git "a/\347\274\226\350\257\221\345\244\247\344\275\234\344\270\232-\347\254\254\344\270\200\351\203\250\345\210\206.md" "b/\347\274\226\350\257\221\345\244\247\344\275\234\344\270\232-\347\254\254\344\270\200\351\203\250\345\210\206.md" index 5931a14..f8c60b2 100644 --- "a/\347\274\226\350\257\221\345\244\247\344\275\234\344\270\232-\347\254\254\344\270\200\351\203\250\345\210\206.md" +++ "b/\347\274\226\350\257\221\345\244\247\344\275\234\344\270\232-\347\254\254\344\270\200\351\203\250\345\210\206.md" @@ -16,7 +16,7 @@ CList: 常量列表 AList: 变量列表 ``` ```bnf -P ::= P S| S +P ::= P S | S S ::= LHS = RHS ; LHS ::= TRef RHS ::= RHS + RHS @@ -33,7 +33,7 @@ TRef ::= Id < CList > [ AList ] SRef ::= Id < CList > CList ::= CList , IntV | IntV AList ::= AList , IdExpr | IdExpr -IdExpr ::= Id | IdExpr + IdExpr | IdExpr + IntV | IdExpr * IntV | IdExpr // IntV | IdExpr % IntV | (IdExpr) +IdExpr ::= Id | IdExpr + IdExpr | IdExpr - IdExpr | IdExpr + IntV | IdExpr * IntV | IdExpr // IntV | IdExpr % IntV | (IdExpr) Const ::= FloatV | IntV ``` 另外,虽然文法允许张量名字出现在AList中,但是在语义上这是不允许的,所以在测试例子中不会出现这种输入。 diff --git "a/\347\274\226\350\257\221\345\244\247\344\275\234\344\270\232-\347\254\254\344\272\214\351\203\250\345\210\206.md" "b/\347\274\226\350\257\221\345\244\247\344\275\234\344\270\232-\347\254\254\344\272\214\351\203\250\345\210\206.md" index f4f6ccb..57e960f 100644 --- "a/\347\274\226\350\257\221\345\244\247\344\275\234\344\270\232-\347\254\254\344\272\214\351\203\250\345\210\206.md" +++ "b/\347\274\226\350\257\221\345\244\247\344\275\234\344\270\232-\347\254\254\344\272\214\351\203\250\345\210\206.md" @@ -22,7 +22,9 @@ loss = mse_loss(Y2, T) # loss is scalar #### 2.2 问题定义 __现在我们开始进行问题描述:__ -对于一个给定的表达式$Output = expr(Input_1, Input_2, ..., Input_n)$($Output, Input_i$是张量或标量, $expr()$表示用其参数构造一个表达式),我们如果已知了最终loss对于$Output$的导数$dOutput = \frac{\partial loss}{\partial Output}$,我们想知道loss对于某个输入的导函数是什么,也就是求$dInput_i = \frac{\partial loss}{\partial Input_i}$的问题,**我们要求必须从表达式IR层面分析出求导的表达式,而不能根据case的名字硬编码答案**。 +对于一个给定的表达式$Output = expr(Input_1, Input_2, ..., Input_n)$($Output, Input_i$是张量或标量, $expr()$表示用其参数构造一个表达式),我们如果已知了最终loss对于$Output$的导数$dOutput = \frac{\partial loss}{\partial Output}$,我们想知道loss对于某个输入的导函数是什么,也就是求$dInput_i = \frac{\partial loss}{\partial Input_i}$的问题,**我们要求如下:** +- 分析出来的求导表达式是一个或多个赋值语句形式,每个语句左侧的下标索引上不能有加减乘除等运算,也就是不能出现`A[i+1] = B[i]`的形式。 +- 必须通过对输入表达式的编译分析过程,综合出求导表达式的内容,并生成代码,不能通过判断case的名字直接得出求导表达式(这样就和传统框架一样了),也不能用打表法直接打印出字符串 #### 2.3一个例子 为了帮助理解,我们给一个例子: @@ -194,7 +196,6 @@ pdf评分标准: #### 4.4 审查法 审查代码是为了防止同学作弊,作弊的定义包含: - 拷贝或修改run2.h/run2.cc/clean2.cc的内容 -- 不使用任何编译分析技术直接输出字符串到kernels/目录下 - 任何两组的代码重合度过高甚至完全一致 - 完全使用第三方项目解决问题 - 报告内容与实际实现不一致 @@ -217,7 +218,44 @@ https://github.com/halide/Halide/blob/master/src/Derivative.cpp https://github.com/apache/incubator-tvm/pull/2498 4. 二维卷积反向传播推导 https://zhuanlan.zhihu.com/p/61898234 +5. 线性下标变换下求导方法 +https://arxiv.org/abs/1711.01348 ### 6. 讨论 Project可能潜在的bug可以在微信群、github issue上提出,有价值的issue可以为全组加分,每个bug加1分 -另外,鼓励小组内部协作与讨论,也鼓励适当的小组间交流,交流方式为github issue或微信群,助教也会参与讨论,解答一些技术问题。 \ No newline at end of file +另外,鼓励小组内部协作与讨论,也鼓励适当的小组间交流,交流方式为github issue或微信群,助教也会参与讨论,解答一些技术问题。 + +### 附录 +#### 1. IRMutator的使用 +IRMutator的功能是遍历IR,并且在遍历到每个节点的时候,返回一个新的IR节点。默认的IRMutator行为是返回和先前一摸一样的新节点。实际使用时,可以通过继承IRMutator,并重载特定的visit函数来定制对于IRMutator的遍历和修改行为。所有通过IRMutator对于AST的修改,都是创造新的AST,所以不会影响原来的AST的内容。 + +在test/目录下,ir_mutator.cc文件中展示了一个简单的定制Mutator的过程: +```c +class MyMutator : public IRMutator { + public: + Expr visit(Ref op) override { + if (op->name == "A") { + return Var::make(op->type(), "modified_A", op->args, op->shape); + } + return IRMutator::visit(op); + } +}; +``` +利用这个Mutator,可以把表达式里名字为"A"的Var节点更改为名字为"modified_A"的Var节点。 +```c +MyMutator mutator; +kernel = mutator.mutate(kernel); +``` +更改后的kernel,打印出来是这样的: +```py + simple_gemm(modified_A<1024, 256>, B<256, 512>, C<1024, 512>) { + for i in dom[((int32_t <1>) 0), ((int32_t <1>) 1024)){ + for j in dom[((int32_t <1>) 0), ((int32_t <1>) 512)){ + for k in dom[((int32_t <1>) 0), ((int32_t <1>) 256)){ + C[i, j] = C[i, j] + modified_A[i, k] * B[k, j] + } + } + } +} +``` +可以看到名字的确改了。这样,我们可以利用IRMutator实现很多不同的pass。 \ No newline at end of file