forked from lucidrains/vit-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrecorder.py
59 lines (46 loc) · 1.73 KB
/
recorder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from functools import wraps
import torch
from torch import nn
from vit_pytorch.vit import Attention
def find_modules(nn_module, type):
return [module for module in nn_module.modules() if isinstance(module, type)]
class Recorder(nn.Module):
def __init__(self, vit, device = None):
super().__init__()
self.vit = vit
self.data = None
self.recordings = []
self.hooks = []
self.hook_registered = False
self.ejected = False
self.device = device
def _hook(self, _, input, output):
self.recordings.append(output.clone().detach())
def _register_hook(self):
modules = find_modules(self.vit.transformer, Attention)
for module in modules:
handle = module.attend.register_forward_hook(self._hook)
self.hooks.append(handle)
self.hook_registered = True
def eject(self):
self.ejected = True
for hook in self.hooks:
hook.remove()
self.hooks.clear()
return self.vit
def clear(self):
self.recordings.clear()
def record(self, attn):
recording = attn.clone().detach()
self.recordings.append(recording)
def forward(self, img):
assert not self.ejected, 'recorder has been ejected, cannot be used anymore'
self.clear()
if not self.hook_registered:
self._register_hook()
pred = self.vit(img)
# move all recordings to one device before stacking
target_device = self.device if self.device is not None else img.device
recordings = tuple(map(lambda t: t.to(target_device), self.recordings))
attns = torch.stack(recordings, dim = 1) if len(recordings) > 0 else None
return pred, attns