make extractor flexible for layers that output multiple tensors, show CrossViT example
lucidrains committed Jun 19, 2022
1 parent b3e90a2 commit 4e62e5f
Showing 3 changed files with 51 additions and 4 deletions.
README.md: 41 additions, 0 deletions
@@ -1255,6 +1255,47 @@ logits, embeddings = v(img)
embeddings # (1, 65, 1024) - (batch x patches x model dim)
```

Or take `CrossViT`, which has a multi-scale encoder that outputs two sets of embeddings, one for the 'small' scale and one for the 'large' scale:

```python
import torch
from vit_pytorch.cross_vit import CrossViT

v = CrossViT(
    image_size = 256,
    num_classes = 1000,
    depth = 4,
    sm_dim = 192,
    sm_patch_size = 16,
    sm_enc_depth = 2,
    sm_enc_heads = 8,
    sm_enc_mlp_dim = 2048,
    lg_dim = 384,
    lg_patch_size = 64,
    lg_enc_depth = 3,
    lg_enc_heads = 8,
    lg_enc_mlp_dim = 2048,
    cross_attn_depth = 2,
    cross_attn_heads = 8,
    dropout = 0.1,
    emb_dropout = 0.1
)

# wrap the CrossViT

from vit_pytorch.extractor import Extractor
v = Extractor(v, layer_name = 'multi_scale_encoder') # take the embeddings coming from the output of the multi-scale encoder

# forward pass now returns the predictions and the embeddings from the multi-scale encoder

img = torch.randn(1, 3, 256, 256)
logits, embeddings = v(img)

# there is one extra token due to the CLS token

embeddings # ((1, 257, 192), (1, 17, 384)) - (batch x patches x dimension) <- small and large scales respectively
```
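Since the multi-scale encoder returns a pair of tensors, the extracted embeddings come back as a tuple as well. A minimal sketch of unpacking them (variable names are illustrative, not part of the commit):

```python
sm_tokens, lg_tokens = embeddings

sm_tokens.shape # (1, 257, 192) - 16x16 = 256 small-scale patches, plus one CLS token
lg_tokens.shape # (1, 17, 384)  - 4x4 = 16 large-scale patches, plus one CLS token
```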

## Research Ideas

### Efficient Attention
setup.py: 2 additions, 1 deletion
@@ -3,9 +3,10 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.35.2',
+  version = '0.35.3',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
+  long_description_content_type = 'text/markdown',
   author = 'Phil Wang',
   author_email = '[email protected]',
   url = 'https://github.com/lucidrains/vit-pytorch',
vit_pytorch/extractor.py: 8 additions, 3 deletions
@@ -4,6 +4,11 @@
 def exists(val):
     return val is not None

+def apply_tuple_or_single(fn, val):
+    if isinstance(val, tuple):
+        return tuple(map(fn, val))
+    return fn(val)
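This new helper applies a function uniformly whether the hooked layer returned a single tensor or a tuple of tensors. A quick illustration of its behavior (an assumed example, not part of the commit):

```python
import torch
from vit_pytorch.extractor import apply_tuple_or_single

single = torch.zeros(2)
pair = (torch.zeros(2), torch.zeros(3))

apply_tuple_or_single(lambda t: t.shape, single) # torch.Size([2])
apply_tuple_or_single(lambda t: t.shape, pair)   # (torch.Size([2]), torch.Size([3]))
```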

 class Extractor(nn.Module):
     def __init__(
         self,
@@ -28,8 +33,8 @@ def __init__(
         self.return_embeddings_only = return_embeddings_only

     def _hook(self, _, inputs, output):
-        tensor_to_save = inputs if self.layer_save_input else output
-        self.latents = tensor_to_save.clone().detach()
+        layer_output = inputs if self.layer_save_input else output
+        self.latents = apply_tuple_or_single(lambda t: t.clone().detach(), layer_output)

     def _register_hook(self):
         assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
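For context, `_register_hook` attaches `_hook` to the named layer via PyTorch's `register_forward_hook`, which is what lets the wrapper capture intermediate outputs without modifying the model. A self-contained sketch of that mechanism, independent of this repository:

```python
import torch
from torch import nn

layer = nn.Linear(4, 4)
captured = {}

def hook(module, inputs, output):
    # same (module, inputs, output) signature PyTorch passes to forward hooks
    captured['latents'] = output.clone().detach()

handle = layer.register_forward_hook(hook)
layer(torch.randn(1, 4))
print(captured['latents'].shape) # torch.Size([1, 4])
handle.remove()                  # detach the hook when done
```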
@@ -62,7 +67,7 @@ def forward(
         pred = self.vit(img)

         target_device = self.device if exists(self.device) else img.device
-        latents = self.latents.to(target_device)
+        latents = apply_tuple_or_single(lambda t: t.to(target_device), self.latents)

         if return_embeddings_only or self.return_embeddings_only:
             return latents
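Taken together, a minimal usage sketch of the updated extractor with the `return_embeddings_only` flag stored in `__init__` above, assuming an unwrapped CrossViT instance `v` and the image `img` from the README example (this usage is an assumption, not shown in the commit):

```python
from vit_pytorch.extractor import Extractor

v = Extractor(v, layer_name = 'multi_scale_encoder', return_embeddings_only = True)

embeddings = v(img) # returns only the (small, large) embedding tuple, no logits
```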
