Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
KnowingNothing committed May 10, 2020
1 parent f8e7fec commit 5385396
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 5 deletions.
78 changes: 78 additions & 0 deletions test/ir_mutator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#include <string>
#include <iostream>

#include "IR.h"
#include "IRMutator.h"
#include "IRVisitor.h"
#include "IRPrinter.h"
#include "type.h"

using namespace Boost::Internal;


class MyMutator : public IRMutator {
public:
Expr visit(Ref<const Var> op) override {
if (op->name == "A") {
return Var::make(op->type(), "modified_A", op->args, op->shape);
}
return IRMutator::visit(op);
}
};


int main() {
const int M = 1024;
const int N = 512;
const int K = 256;
Type index_type = Type::int_scalar(32);
Type data_type = Type::float_scalar(32);

// index i
Expr dom_i = Dom::make(index_type, 0, M);
Expr i = Index::make(index_type, "i", dom_i, IndexType::Spatial);

// index j
Expr dom_j = Dom::make(index_type, 0, N);
Expr j = Index::make(index_type, "j", dom_j, IndexType::Spatial);

// index k
Expr dom_k = Dom::make(index_type, 0, K);
Expr k = Index::make(index_type, "k", dom_k, IndexType::Reduce);

// A
Expr expr_A = Var::make(data_type, "A", {i, k}, {M, K});

// B
Expr expr_B = Var::make(data_type, "B", {k, j}, {K, N});

// C
Expr expr_C = Var::make(data_type, "C", {i, j}, {M, N});

// main stmt
Stmt main_stmt = Move::make(
expr_C,
Binary::make(data_type, BinaryOpType::Add, expr_C,
Binary::make(data_type, BinaryOpType::Mul, expr_A, expr_B)),
MoveType::MemToMem
);

// loop nest
Stmt loop_nest = LoopNest::make({i, j, k}, {main_stmt});

// kernel
Group kernel = Kernel::make("simple_gemm", {expr_A, expr_B}, {expr_C}, {loop_nest}, KernelType::CPU);

// mutator
MyMutator mutator;
kernel = mutator.mutate(kernel);

// printer
IRPrinter printer;
std::string code = printer.print(kernel);

std::cout << code;

std::cout << "Success!\n";
return 0;
}
4 changes: 2 additions & 2 deletions 编译大作业-第一部分.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ CList: 常量列表
AList: 变量列表
```
```bnf
P ::= P S| S
P ::= P S | S
S ::= LHS = RHS ;
LHS ::= TRef
RHS ::= RHS + RHS
Expand All @@ -33,7 +33,7 @@ TRef ::= Id < CList > [ AList ]
SRef ::= Id < CList >
CList ::= CList , IntV | IntV
AList ::= AList , IdExpr | IdExpr
IdExpr ::= Id | IdExpr + IdExpr | IdExpr + IntV | IdExpr * IntV | IdExpr // IntV | IdExpr % IntV | (IdExpr)
IdExpr ::= Id | IdExpr + IdExpr | IdExpr - IdExpr | IdExpr + IntV | IdExpr * IntV | IdExpr // IntV | IdExpr % IntV | (IdExpr)
Const ::= FloatV | IntV
```
另外,虽然文法允许张量名字出现在AList中,但是在语义上这是不允许的,所以在测试例子中不会出现这种输入。
Expand Down
44 changes: 41 additions & 3 deletions 编译大作业-第二部分.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ loss = mse_loss(Y2, T) # loss is scalar

#### 2.2 问题定义
__现在我们开始进行问题描述:__
对于一个给定的表达式$Output = expr(Input_1, Input_2, ..., Input_n)$($Output, Input_i$是张量或标量, $expr()$表示用其参数构造一个表达式),我们如果已知了最终loss对于$Output$的导数$dOutput = \frac{\partial loss}{\partial Output}$,我们想知道loss对于某个输入的导函数是什么,也就是求$dInput_i = \frac{\partial loss}{\partial Input_i}$的问题,**我们要求必须从表达式IR层面分析出求导的表达式,而不能根据case的名字硬编码答案**
对于一个给定的表达式$Output = expr(Input_1, Input_2, ..., Input_n)$($Output, Input_i$是张量或标量, $expr()$表示用其参数构造一个表达式),我们如果已知了最终loss对于$Output$的导数$dOutput = \frac{\partial loss}{\partial Output}$,我们想知道loss对于某个输入的导函数是什么,也就是求$dInput_i = \frac{\partial loss}{\partial Input_i}$的问题,**我们要求如下:**
- 分析出来的求导表达式是一个或多个赋值语句形式,每个语句左侧的下标索引上不能有加减乘除等运算,也就是不能出现`A[i+1] = B[i]`的形式。
- 必须通过对输入表达式的编译分析过程,综合出求导表达式的内容,并生成代码,不能通过判断case的名字直接得出求导表达式(这样就和传统框架一样了),也不能用打表法直接打印出字符串

#### 2.3一个例子
为了帮助理解,我们给一个例子:
Expand Down Expand Up @@ -194,7 +196,6 @@ pdf评分标准:
#### 4.4 审查法
审查代码是为了防止同学作弊,作弊的定义包含:
- 拷贝或修改run2.h/run2.cc/clean2.cc的内容
- 不使用任何编译分析技术直接输出字符串到kernels/目录下
- 任何两组的代码重合度过高甚至完全一致
- 完全使用第三方项目解决问题
- 报告内容与实际实现不一致
Expand All @@ -217,7 +218,44 @@ https://github.com/halide/Halide/blob/master/src/Derivative.cpp
https://github.com/apache/incubator-tvm/pull/2498
4. 二维卷积反向传播推导
https://zhuanlan.zhihu.com/p/61898234
5. 线性下标变换下求导方法
https://arxiv.org/abs/1711.01348

### 6. 讨论
Project可能潜在的bug可以在微信群、github issue上提出,有价值的issue可以为全组加分,每个bug加1分
另外,鼓励小组内部协作与讨论,也鼓励适当的小组间交流,交流方式为github issue或微信群,助教也会参与讨论,解答一些技术问题。
另外,鼓励小组内部协作与讨论,也鼓励适当的小组间交流,交流方式为github issue或微信群,助教也会参与讨论,解答一些技术问题。

### 附录
#### 1. IRMutator的使用
IRMutator的功能是遍历IR,并且在遍历到每个节点的时候,返回一个新的IR节点。默认的IRMutator行为是返回和先前一摸一样的新节点。实际使用时,可以通过继承IRMutator,并重载特定的visit函数来定制对于IRMutator的遍历和修改行为。所有通过IRMutator对于AST的修改,都是创造新的AST,所以不会影响原来的AST的内容。

在test/目录下,ir_mutator.cc文件中展示了一个简单的定制Mutator的过程:
```c
class MyMutator : public IRMutator {
public:
Expr visit(Ref<const Var> op) override {
if (op->name == "A") {
return Var::make(op->type(), "modified_A", op->args, op->shape);
}
return IRMutator::visit(op);
}
};
```
利用这个Mutator,可以把表达式里名字为"A"的Var节点更改为名字为"modified_A"的Var节点。
```c
MyMutator mutator;
kernel = mutator.mutate(kernel);
```
更改后的kernel,打印出来是这样的:
```py
<CPU> simple_gemm(modified_A<1024, 256>, B<256, 512>, C<1024, 512>) {
for i<spatial> in dom[((int32_t <1>) 0), ((int32_t <1>) 1024)){
for j<spatial> in dom[((int32_t <1>) 0), ((int32_t <1>) 512)){
for k<reduce> in dom[((int32_t <1>) 0), ((int32_t <1>) 256)){
C[i, j] =<mem_to_mem> C[i, j] + modified_A[i, k] * B[k, j]
}
}
}
}
```
可以看到名字的确改了。这样,我们可以利用IRMutator实现很多不同的pass

0 comments on commit 5385396

Please sign in to comment.