allow continuous transformer wrapper to have prepend embedding mask too
lucidrains committed Nov 27, 2023
1 parent 13ac0cb commit 1f0232f
Showing 2 changed files with 13 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.25.11',
+  version = '1.25.15',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
14 changes: 12 additions & 2 deletions x_transformers/continuous.py
@@ -85,9 +85,10 @@ def forward(
         mems = None,
         pos = None,
         prepend_embeds = None,
+        prepend_mask = None,
         **kwargs
     ):
-        batch = x.shape[0]
+        batch, seq, device = *x.shape[:2], x.device

         x = self.project_in(x)
         x = x + self.pos_emb(x, pos = pos)
@@ -107,11 +108,18 @@ def forward(
         # whether to append embeds, as in PaLI, for image embeddings

         if exists(prepend_embeds):
-            _, prepend_dim = prepend_embeds.shape[1:]
+            prepend_seq, prepend_dim = prepend_embeds.shape[1:]

             assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'

             x = torch.cat((prepend_embeds, x), dim = -2)

+            if exists(prepend_mask) or exists(mask):
+                mask = default(mask, lambda: torch.ones((batch, seq), device = device, dtype = torch.bool))
+                prepend_mask = default(prepend_mask, lambda: torch.ones((batch, prepend_seq), device = device, dtype = torch.bool))
+
+                mask = torch.cat((prepend_mask, mask), dim = -1)
+
         x = self.emb_dropout(x)

         # attention layers
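
For context, a minimal usage sketch of the new argument (model sizes and tensor shapes below are illustrative, not taken from this commit): prepend_embeds must match the model dimension of attn_layers, since the assert runs after project_in, and prepend_mask covers only the prepended positions.

import torch
from x_transformers import ContinuousTransformerWrapper, Encoder

model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 100,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

x = torch.randn(2, 256, 32)             # continuous inputs, projected to the model dim (512) internally
mask = torch.ones(2, 256).bool()        # mask over the original sequence

image_embeds = torch.randn(2, 16, 512)  # prepended embeddings, already in the model dim (512)
image_mask = torch.ones(2, 16).bool()   # new in this commit: mask over the prepended positions

out = model(
    x,
    mask = mask,
    prepend_embeds = image_embeds,
    prepend_mask = image_mask
)

# the output covers the prepended positions plus the original sequence
assert out.shape == (2, 16 + 256, 100)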
@@ -186,6 +194,8 @@ def generate(self, start_tokens, seq_len, **kwargs):
     def forward(self, x, **kwargs):
         inp, target = x[:, :-1], x[:, 1:]

+        assert 'prepend_embeds' not in kwargs
+
         mask = kwargs.get('mask', None)
         if exists(mask) and mask.shape[1] == x.shape[1]:
             mask = mask[:, :-1]
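
The new assert guards training through ContinuousAutoregressiveWrapper: its forward shifts the sequence into teacher-forced inputs and next-step targets, so a prepended prefix would make the model output longer than the target and misalign the loss. A rough illustration (shapes made up):

import torch

x = torch.randn(2, 100, 32)         # full continuous sequence
inp, target = x[:, :-1], x[:, 1:]   # 99 input steps, 99 target steps

# with prepend_embeds the inner model would return prepend_seq + 99 positions,
# while target still has 99 -- hence prepend_embeds is rejected in this forward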
