Skip to content

Commit

Permalink
add extractor wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 21, 2021
1 parent b983bbe commit 2c368d1
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 1 deletion.
35 changes: 35 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,41 @@ to cleanup the class and the hooks once you have collected enough data
v = v.eject() # wrapper is discarded and original ViT instance is returned
```

## Accessing Embeddings

You can similarly access the embeddings with the `Extractor` wrapper

```python
import torch
from vit_pytorch.vit import ViT

v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)

# import Extractor and wrap the ViT

from vit_pytorch.extractor import Extractor
v = Extractor(v)

# forward pass now returns predictions and the embeddings

img = torch.randn(1, 3, 256, 256)
logits, embeddings = v(img)

# there is one extra token due to the CLS token

embeddings # (1, 65, 1024) - (batch x patches x model dim)
```

## Research Ideas

### Efficient Attention
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.25.0',
version = '0.25.1',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
Expand Down
48 changes: 48 additions & 0 deletions vit_pytorch/extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import torch
from torch import nn

def exists(val):
    """Return ``True`` when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True

class Extractor(nn.Module):
    """Wrap a model and capture the output of its ``transformer`` submodule.

    A forward hook is attached to ``vit.transformer`` on the first forward
    pass. Each call returns the wrapped model's prediction together with a
    detached copy of whatever ``vit.transformer`` produced during that call.

    Args:
        vit: the module to wrap; it must expose a ``transformer`` submodule.
        device: optional device the captured latents are moved to before
            being returned. When ``None``, the device of the input is used.
    """

    def __init__(self, vit, device = None):
        super().__init__()
        self.vit = vit

        # where the returned latents should live (None -> follow the input)
        self.device = device

        # lifecycle flags
        self.ejected = False
        self.hook_registered = False

        # handles of the registered forward hooks, kept so eject() can remove them
        self.hooks = []

        # most recent capture from the hook; reset at the start of every forward
        self.latents = None
        self.data = None

    def _hook(self, _, inputs, output):
        """Forward-hook callback: stash a detached copy of ``output``."""
        snapshot = output.clone()
        self.latents = snapshot.detach()

    def _register_hook(self):
        """Attach ``_hook`` to ``self.vit.transformer`` and remember the handle."""
        transformer = self.vit.transformer
        self.hooks.append(transformer.register_forward_hook(self._hook))
        self.hook_registered = True

    def eject(self):
        """Remove every registered hook and hand back the wrapped module.

        The extractor is flagged as ejected, after which ``forward`` refuses
        to run.
        """
        self.ejected = True
        for handle in tuple(self.hooks):
            handle.remove()
        del self.hooks[:]
        return self.vit

    def clear(self):
        """Drop the currently stored latents and reset the slot to ``None``."""
        del self.latents
        self.latents = None

    def forward(self, img):
        """Run the wrapped module on ``img``.

        Returns:
            A ``(pred, latents)`` tuple: ``pred`` is the wrapped module's
            output and ``latents`` is the detached capture from the hook,
            moved to ``self.device`` if set, otherwise to ``img.device``.

        Raises:
            AssertionError: if ``eject()`` has already been called.
        """
        assert not self.ejected, 'extractor has been ejected, cannot be used anymore'

        # discard anything left over from a previous call
        self.clear()

        # the hook is attached lazily, the first time forward runs
        if not self.hook_registered:
            self._register_hook()

        pred = self.vit(img)

        if exists(self.device):
            destination = self.device
        else:
            destination = img.device

        return pred, self.latents.to(destination)

0 comments on commit 2c368d1

Please sign in to comment.