Commit

allow extractor to only return embeddings, to ready for vision transformers to be used in x-clip
lucidrains committed Dec 25, 2021
1 parent 0891885 commit e52ac41
Showing 2 changed files with 23 additions and 4 deletions.
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.25.5',
+  version = '0.25.6',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
vit_pytorch/extractor.py (25 changes: 22 additions & 3 deletions)
@@ -5,7 +5,13 @@ def exists(val):
     return val is not None
 
 class Extractor(nn.Module):
-    def __init__(self, vit, device = None):
+    def __init__(
+        self,
+        vit,
+        device = None,
+        layer_name = 'transformer',
+        return_embeddings_only = False
+    ):
         super().__init__()
         self.vit = vit
 
@@ -16,11 +22,16 @@ def __init__(self, vit, device = None):
         self.ejected = False
         self.device = device
 
+        self.layer_name = layer_name
+        self.return_embeddings_only = return_embeddings_only
+
     def _hook(self, _, input, output):
         self.latents = output.clone().detach()
 
     def _register_hook(self):
-        handle = self.vit.transformer.register_forward_hook(self._hook)
+        assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
+        layer = getattr(self.vit, self.layer_name)
+        handle = layer.register_forward_hook(self._hook)
         self.hooks.append(handle)
         self.hook_registered = True
 
@@ -35,7 +46,11 @@ def clear(self):
         del self.latents
         self.latents = None
 
-    def forward(self, img):
+    def forward(
+        self,
+        img,
+        return_embeddings_only = False
+    ):
         assert not self.ejected, 'extractor has been ejected, cannot be used anymore'
         self.clear()
         if not self.hook_registered:
@@ -45,4 +60,8 @@ def forward(self, img):
 
         target_device = self.device if exists(self.device) else img.device
         latents = self.latents.to(target_device)
+
+        if return_embeddings_only or self.return_embeddings_only:
+            return latents
+
         return pred, latents
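
As a usage note: with these flags, a wrapped ViT can hand back just the embeddings captured at the hooked layer, which is the form x-clip expects from its image encoder. A minimal sketch of how this might be called, assuming the ViT constructor arguments shown in the vit-pytorch README (the Extractor arguments come from the diff above; the commented output shape is an expectation for these settings, not something stated in this commit):

import torch
from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor

# a plain vision transformer; constructor arguments assumed from the vit-pytorch README
vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# wrap it so a forward pass returns only the latents captured at the default 'transformer' layer
vit = Extractor(vit, return_embeddings_only = True)

img = torch.randn(1, 3, 256, 256)

embeddings = vit(img)  # expected shape (1, 65, 1024): 64 patch tokens plus the CLS token

The same behavior is also available per call, vit(img, return_embeddings_only = True), and layer_name lets the hook target a differently named attribute if the wrapped model does not store its transformer under .transformer.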
