
Commit 03382c8

first commit

luo3300612 committed Aug 9, 2021 · 0 parents
Showing 184 changed files with 37,251 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -0,0 +1,4 @@
examples
venv
model_test.py
.idea
113 changes: 113 additions & 0 deletions README.md
@@ -0,0 +1,113 @@
# README
Visualizer is a small tool that helps visualize the attention modules in deep learning models. Its main job is to extract attention maps that are nested deep inside a model.

## Why Visualizer?
To visualize an attention map, have you ever suffered from the following?
* Passing an attention map nested deep inside the model back out through layer after layer of `return` statements, then having to undo it all when it's time to train
* Recording the attention map in a global variable right inside the attention function, then forgetting to change it back before training and hitting an OOM

Whether or not you have, I certainly have.

After consulting a **professional**, I learned that PyTorch has hooks for extracting intermediate results. A quick look confirmed that intermediate variables can indeed be retrieved this way, but it requires registering a hook along these lines:
```python
handle = net.conv2.register_forward_hook(hook)
```
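For context, here is a minimal end-to-end sketch of the hook-based route (the `SmallNet` model, the `save_output` hook, and the `features` dict are illustrative assumptions, not part of Visualizer):
```python
import torch
import torch.nn as nn

class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.conv2 = nn.Conv2d(8, 16, 3)

    def forward(self, x):
        return self.conv2(self.conv1(x))

net = SmallNet()
features = {}

def save_output(module, inputs, output):
    # runs on every forward pass of the hooked module
    features['conv2'] = output.detach()

handle = net.conv2.register_forward_hook(save_output)
net(torch.randn(1, 3, 32, 32))
handle.remove()  # unregister, or the hook stays active
```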
This only works if we already know the name of the module we want to hook into, but Transformer-style models are usually defined like this (take ViT as an example):
```python
class VisionTransformer(nn.Module):
    def __init__(self, *args, **kwargs):
        ...
        self.blocks = nn.Sequential(*[Block(...) for i in range(depth)])
        ...
```
and each `Block` contains an Attention module:
```python
class Block(nn.Module):
    def __init__(self, *args, **kwargs):
        ...
        self.attn = Attention(...)
        ...
```
If we try to use hooks here, the problems are:
1. We have no idea how to reach the attention map by chaining attributes like `model.bla.bla.bla`!
2. A Transformer has one attention map per layer, and registering a hook for each one is far too tedious (see the sketch below)
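In principle both problems can be worked around by looping over `model.named_modules()`. A rough sketch, assuming the attention class is literally named `Attention` (as it is in timm's ViT):

```python
attn_outputs = []

def save_attn(module, inputs, output):
    attn_outputs.append(output.detach())

handles = []
for name, module in model.named_modules():
    # hook every attention block, matched by class name
    if module.__class__.__name__ == 'Attention':
        handles.append(module.register_forward_hook(save_attn))
```

Even then, a forward hook only sees a module's inputs and outputs, not the attention weights computed as a *local variable* inside `forward`, and that is exactly the gap this tool fills.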

So I went looking for a cleaner way to get attention maps (especially from Transformers), and visualizer is one such way. Its features:
* Precise and direct: you can extract any intermediate result in the model by its variable name
* Quick and convenient: it grabs all the attention maps in a Transformer-style model at once
* Non-invasive: you don't need to modify a single line of code inside the function
* Train-test consistent: once visualization is done, there is no code to revert before training

## Usage
### Installation
```shell
python setup.py install
```

Once installed, just decorate the attention function with `get_local`; after a forward pass you can retrieve any local variable in that function whose name matches the decorator's argument.
### Usage 1
For example, say I want the `attention_map` variable inside a function:
```python
from visualizer import get_local
get_local.activate() # activate visualizer

from ... import model

@get_local('attention_map')
def your_attention_function(*args, **kwargs):
    ...
    attention_map = ...
    ...
    return ...

# load model and data
...
out = model(data)

cache = get_local.cache # -> {'your_attention_function': [attention_map]}
```
Everything ends up in the `get_local.cache` dictionary, where each key is a function name and each value is a list of the attention maps that function produced.
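A small sketch of how the cache might then be inspected (the shape comment is an assumption about a typical multi-head attention map, not something visualizer guarantees):

```python
for name, maps in get_local.cache.items():
    # one list entry per decorated call; a single forward pass of a
    # 12-layer ViT would therefore yield 12 entries
    print(name, len(maps), maps[0].shape)  # e.g. (batch, heads, tokens, tokens)
```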

### Usage 2
In PyTorch we usually wrap modules in classes. The approach is the same: just decorate the method inside the class that computes the attention map.
```python
from visualizer import get_local
get_local.activate() # activate visualizer

class Attention(nn.Module):
    def __init__(self):
        ...

    @get_local('attn_map')
    def forward(self, x):
        ...
        attn_map = ...
        ...
        return ...

# load model and data
...
out = model(data)

cache = get_local.cache # -> {'Attention.forward': [attn_map]}
```
For other details, see demo.ipynb.
## Visualization Results
Below are visualization results for vit_small; the full walkthrough is in demo.ipynb.
Since a vanilla ViT computes every attention map inside `Attention.forward`, decorating that single function lets us pull out the attention maps of all 12 Transformer layers at once!
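As a rough sketch of how one of these maps might be rendered (assuming each cached map is an array of shape (batch, heads, tokens, tokens) and that matplotlib is installed; the layer and head indices are arbitrary):

```python
import matplotlib.pyplot as plt

attention_maps = get_local.cache['Attention.forward']  # one entry per layer
layer, head = 0, 0
plt.imshow(attention_maps[layer][0, head], cmap='viridis')
plt.title(f'layer {layer}, head {head}')
plt.colorbar()
plt.show()
```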


## Notes
* The variable you want to visualize must not be overwritten by a later variable of the same name inside the function, because get_local captures the **final value** that name takes within the function (see the sketch after this list)
* When visualizing, get_local.activate() must be called **before** the model is imported, because Python decorators are executed at import time
* During training you don't need to modify or delete anything, i.e. the decorator can stay on the function: as long as get_local.activate() is never called, the attention function is never actually decorated, so there is no performance cost (again because Python decorators are executed at import time)
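To illustrate the first note, a toy example (the `compute` function and its contents are made up for demonstration):

```python
import torch
from visualizer import get_local
get_local.activate()

@get_local('attn')
def compute(x):
    attn = torch.softmax(x, dim=-1)  # the map we actually wanted
    attn = attn * 0                  # same name reassigned: this is what gets cached
    return attn

compute(torch.randn(2, 2))
print(get_local.cache['compute'][0])  # the zeroed map, not the softmax
```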

## 其他
Of course, get_local can really capture the final value of any local variable in any function, so it should have other, more interesting uses as well.

## References
* [bytecode](https://blog.csdn.net/qfcy_/article/details/118890362)
* [local track1](https://stackoverflow.com/questions/52313851/how-can-i-track-the-values-of-a-local-variable-in-python)
* [local track2](https://stackoverflow.com/questions/19326004/access-a-function-variable-outside-the-function-without-using-global)
* [decorator1](https://stackoverflow.com/questions/1367514/how-to-decorate-a-method-inside-a-class)
* [decorator2](https://stackoverflow.com/questions/6676015/class-decorators-vs-function-decorators)
Binary file added assets/dogcat.jpg
Binary file added assets/grids.png
Binary file added assets/head.png
Binary file added assets/heads.png
Binary file added assets/rabbit.jpg
716 changes: 716 additions & 0 deletions demo.ipynb


1 change: 1 addition & 0 deletions imagenet_cls.json


9 changes: 9 additions & 0 deletions setup.py
@@ -0,0 +1,9 @@
import setuptools

setuptools.setup(
    name="visualizer",
    version="0.0.1",
    author="luo3300612",
    author_email="[email protected]",
    packages=setuptools.find_packages(),
)
4 changes: 4 additions & 0 deletions timm/__init__.py
@@ -0,0 +1,4 @@
from .version import __version__
from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
    is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \
    get_model_default_value, is_model_pretrained
12 changes: 12 additions & 0 deletions timm/data/__init__.py
@@ -0,0 +1,12 @@
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
    rand_augment_transform, auto_augment_transform
from .config import resolve_data_config
from .constants import *
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset
from .loader import create_loader
from .mixup import Mixup, FastCollateMixup
from .parsers import create_parser
from .real_labels import RealLabelsImagenet
from .transforms import *
from .transforms_factory import create_transform