Customizable Embedder and Logit Mapper (lucidrains#288)
* Allow kwargs in TokenEmbedding's forward method

* Move weight init to embedder class

* Allow custom logits mappers

* Add unit tests for embedder and logits mapper

* Revert "Allow custom logits mappers"

This reverts commit e8166b4.

* Use kwargs for token_emb and to_logits

* Revert to version from main

* Add tests for custom token_emb and to_logits

* Undo accidental deletion of **kwargs

* Simplify test_to_logits

* Remove unused import
pradeep-pyro authored Nov 7, 2024
1 parent 7e15c09 commit 409ba0f
Showing 2 changed files with 100 additions and 11 deletions.
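
In short, TransformerWrapper now accepts a user-supplied embedding module through token_emb and a user-supplied output projection through to_logits, and its forward method passes extra tensors on to them via token_emb_kwargs and to_logits_kwargs. A minimal usage sketch based on the diff below; the module choices and sizes here are illustrative, not part of the library:

import torch
from x_transformers import TransformerWrapper, Decoder

num_tokens, dim = 256, 128

# any nn.Module mapping token ids to embeddings can stand in for the default TokenEmbedding
custom_emb = torch.nn.Embedding(num_tokens, dim)

# any nn.Module mapping hidden states to logits can stand in for the default LinearNoBias head
custom_to_logits = torch.nn.Linear(dim, num_tokens, bias = False)

model = TransformerWrapper(
    num_tokens = num_tokens,
    max_seq_len = 512,
    attn_layers = Decoder(dim = dim, depth = 2, heads = 4),
    token_emb = custom_emb,        # new keyword from this commit
    to_logits = custom_to_logits,  # new keyword from this commit
)

x = torch.randint(0, num_tokens, (1, 512))
logits = model(x)  # (1, 512, num_tokens)
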
82 changes: 81 additions & 1 deletion tests/test_x_transformers.py
@@ -6,7 +6,7 @@
    TransformerWrapper,
    Encoder,
    Decoder,
-    AutoregressiveWrapper,
+    LinearNoBias,
)

from x_transformers.neo_mlp import (
@@ -394,3 +394,83 @@ def test_custom_alibi():
    pos = torch.tensor([[0, 1, 2, 4], [1, 3, 5, 7]])

    logits = model(x, pos = pos)


@pytest.mark.parametrize('embedder_type', ('embedding', 'none', 'custom'))
def test_embedder(embedder_type):
    num_tokens = 20000
    dim = 128
    token_emb_kwargs = {}
    if embedder_type == 'embedding':
        embedder = torch.nn.Embedding(num_tokens, dim)
    elif embedder_type == 'none':
        embedder = None
    else:
        class CustomEmbedder(torch.nn.Module):
            """
            Made up embedder that sums two embeddings. Just to check if we can pass additional input to the embedder's
            forward pass without breaking the model.
            """
            def __init__(self, num_tokens, dim):
                super().__init__()
                self.embed_x = torch.nn.Embedding(num_tokens, dim)
                self.embed_y = torch.nn.Embedding(num_tokens, dim)
            def forward(self, x, y):
                return self.embed_x(x) + self.embed_y(y)
            def init_(self):
                pass
        embedder = CustomEmbedder(num_tokens, dim)
        token_emb_kwargs['y'] = torch.randint(0, num_tokens, (2, 1024))
    model = TransformerWrapper(
        num_tokens = num_tokens,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = dim,
            depth = 6,
            heads = 8,
        ),
        token_emb = embedder,
    )

    x = torch.randint(0, 20000, (2, 1024))

    output = model(x, token_emb_kwargs=token_emb_kwargs)
    assert output.shape == (2, 1024, 20000)


@pytest.mark.parametrize("to_logits", ('linear', 'none', 'pointer'))
def test_to_logits(to_logits):
    num_tokens = 20000
    dim = 128
    to_logits_kwargs = {}
    if to_logits == 'linear':
        logit_mapper = LinearNoBias(dim, num_tokens)
    elif to_logits == 'none':
        logit_mapper = None
    else:
        class PointerNetworkLogits(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.proj_to_pointers = torch.nn.Linear(dim, dim)
            def forward(self, model_embeddings, input_embeddings):
                pointers = self.proj_to_pointers(model_embeddings)
                logits = torch.matmul(pointers, input_embeddings.permute(0, 2, 1))
                return logits

        logit_mapper = PointerNetworkLogits(dim)
        to_logits_kwargs['input_embeddings'] = torch.randn(2, 20000, dim)

    model = TransformerWrapper(
        num_tokens = num_tokens,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = dim,
            depth = 6,
            heads = 8,
        ),
        to_logits = logit_mapper,
    )

    x = torch.randint(0, num_tokens, (2, 1024))
    output = model(x, to_logits_kwargs=to_logits_kwargs)
    assert output.shape == (2, 1024, 20000)
29 changes: 19 additions & 10 deletions x_transformers/x_transformers.py
@@ -238,6 +238,13 @@ def forward(self, x):
        token_emb = self.emb(x.long())
        return l2norm(token_emb) if self.l2norm_embed else token_emb

    def init_(self):
        if self.l2norm_embed:
            nn.init.normal_(self.emb.weight, std=1e-5)
            return
        nn.init.kaiming_normal_(self.emb.weight)


# positional embeddings

class AbsolutePositionalEmbedding(Module):
@@ -2261,7 +2268,8 @@ def __init__(
        token_emb: TokenEmbedding | None = None,
        mixture_of_softmax = False,
        mixture_of_softmax_k = 4,
-        sigsoftmax_logits = False
+        sigsoftmax_logits = False,
+        to_logits: Module | None = None,
    ):
        super().__init__()

@@ -2363,11 +2371,12 @@ def __init__(
        if return_only_embed:
            self.to_logits = None
        elif tie_embedding:
+            assert isinstance(self.token_emb, TokenEmbedding), 'can only tie embedding if using `TokenEmbedding`'
            self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
        elif num_output_heads > 1:
            self.to_logits = ModuleList([LinearNoBias(dim, logits_dim) for _ in range(num_output_heads)])
        else:
-            self.to_logits = LinearNoBias(dim, logits_dim)
+            self.to_logits = LinearNoBias(dim, logits_dim) if to_logits is None else to_logits

        # memory tokens (like [cls]) from Memory Transformers paper

@@ -2388,13 +2397,11 @@ def __init__(
        self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb

    def init_(self):
+        if hasattr(self.token_emb, 'init_'):
+            self.token_emb.init_()
        if self.l2norm_embed:
-            nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)
            if not isinstance(self.pos_emb, always):
                nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
-            return
-
-        nn.init.kaiming_normal_(self.token_emb.emb.weight)

    def forward(
        self,
@@ -2417,7 +2424,9 @@ def forward(
        attn_z_loss_weight = 1e-4,
        seq_start_pos = None,
        cache: LayerIntermediates | None = None,
-        **kwargs
+        token_emb_kwargs = dict(),
+        to_logits_kwargs = dict(),
+        **kwargs,
    ):
        b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask

Expand All @@ -2428,7 +2437,7 @@ def forward(

external_pos_emb = exists(pos) and pos.dtype != torch.long
pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
x = self.token_emb(x) + pos_emb
x = self.token_emb(x, **token_emb_kwargs) + pos_emb

# add additional embeddings

@@ -2583,9 +2592,9 @@ def forward(

        if not return_embeddings:
            if self.has_multiple_heads:
-                logits = tuple(fn(x) for fn in self.to_logits)
+                logits = tuple(fn(x, **to_logits_kwargs) for fn in self.to_logits)
            else:
-                logits = self.to_logits(x)
+                logits = self.to_logits(x, **to_logits_kwargs)

        # maybe sig softmax

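
Beyond the new constructor arguments, the diff above adds two runtime hooks: TransformerWrapper.init_ now delegates to the embedder's own init_ method when one is defined, and forward threads token_emb_kwargs / to_logits_kwargs through to the custom modules. A rough sketch of an embedder using both hooks; the class name, the segment_ids argument, and the sizes are made up for illustration:

import torch
from torch import nn
from x_transformers import TransformerWrapper, Decoder

class SegmentAwareEmbedding(nn.Module):
    # hypothetical embedder: token embedding plus a learned segment embedding
    def __init__(self, num_tokens, dim, num_segments = 4):
        super().__init__()
        self.tok = nn.Embedding(num_tokens, dim)
        self.seg = nn.Embedding(num_segments, dim)

    def forward(self, x, segment_ids):
        # the extra tensor arrives through token_emb_kwargs in the wrapper's forward
        return self.tok(x) + self.seg(segment_ids)

    def init_(self):
        # picked up by TransformerWrapper.init_ via the new hasattr(..., 'init_') check
        nn.init.kaiming_normal_(self.tok.weight)
        nn.init.kaiming_normal_(self.seg.weight)

num_tokens, dim = 256, 64

model = TransformerWrapper(
    num_tokens = num_tokens,
    max_seq_len = 128,
    attn_layers = Decoder(dim = dim, depth = 2, heads = 4),
    token_emb = SegmentAwareEmbedding(num_tokens, dim),
)

x = torch.randint(0, num_tokens, (2, 128))
segments = torch.zeros(2, 128, dtype = torch.long)
out = model(x, token_emb_kwargs = dict(segment_ids = segments))  # (2, 128, num_tokens)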
