Commit

add the tensor product attention, make 1.0 since why not
lucidrains committed Aug 15, 2022
1 parent 7df4720 commit 9866ca5
Showing 3 changed files with 33 additions and 6 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -1580,4 +1580,13 @@ generated = model.generate(start_emb, 17) # (17, 777)
}
```

```bibtex
@misc{schlag2020enhancing,
title = {Enhancing the Transformer with explicit relational encoding for math problem solving},
author = {Imanol Schlag and Paul Smolensky and Roland Fernandez and Nebojsa Jojic and J{\"u}rgen Schmidhuber and Jianfeng Gao},
year = {2020},
url = {https://openreview.net/forum?id=B1xfElrKPr}
}
```

*solve intelligence... then use that to solve everything else.* - Demis Hassabis
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '0.33.1',
version = '1.0.0',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
28 changes: 23 additions & 5 deletions x_transformers/x_transformers.py
@@ -2,7 +2,7 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F
from functools import partial
from functools import partial, wraps
from inspect import isfunction
from collections import namedtuple

@@ -40,6 +40,14 @@ def default(val, d):
def cast_tuple(val, depth):
return val if isinstance(val, tuple) else (val,) * depth

def maybe(fn):
@wraps(fn)
def inner(x, *args, **kwargs):
if not exists(x):
return x
return fn(x, *args, **kwargs)
return inner

class always():
def __init__(self, val):
self.val = val
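
The new `maybe` helper wraps a function so that `None` passes straight through untouched, which lets the forward pass further down map `rearrange` over `(k, v, r)` even when no relations projection exists. A minimal sketch of its behavior, with hypothetical shapes:

```python
import torch
from functools import wraps
from einops import rearrange

def exists(val):
    return val is not None

def maybe(fn):
    @wraps(fn)
    def inner(x, *args, **kwargs):
        if not exists(x):
            return x
        return fn(x, *args, **kwargs)
    return inner

k = torch.randn(1, 16, 512)   # a real tensor is transformed as usual
r = None                      # an optional tensor that was never projected

k = maybe(rearrange)(k, 'b n (h d) -> b h n d', h = 8)   # -> (1, 8, 16, 64)
r = maybe(rearrange)(r, 'b n (h d) -> b h n d', h = 8)   # stays None
```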
@@ -527,7 +535,8 @@ def __init__(
qk_norm_scale = 1,
one_kv_head = False,
shared_kv = False,
value_dim_head = None
value_dim_head = None,
tensor_product = False # https://arxiv.org/abs/2208.06061
):
super().__init__()
self.scale = dim_head ** -0.5
@@ -553,6 +562,9 @@ def __init__(
assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
self.to_v = nn.Linear(dim, v_dim, bias = False) if not shared_kv else None

# relations projection from tp-attention
self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None

# dropout
self.dropout = nn.Dropout(dropout)

@@ -636,10 +648,12 @@ def forward(
k = self.to_k(k_input)
v = self.to_v(v_input) if exists(self.to_v) else k

r = self.to_r(v_input) if exists(self.to_r) else None

q = rearrange(q, 'b n (h d) -> b h n d', h = h)

if not self.one_kv_head:
k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (k, v))
k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = h), (k, v, r))

if exists(rotary_pos_emb) and not has_context:
l = rotary_pos_emb.shape[-1]
@@ -708,8 +722,8 @@ def forward(

if self.causal:
i, j = dots.shape[-2:]
r = torch.arange(i, device = device)
mask = rearrange(r, 'i -> 1 1 i 1') < rearrange(r, 'j -> 1 1 1 j')
range_i = torch.arange(i, device = device)
mask = rearrange(range_i, 'i -> 1 1 i 1') < rearrange(range_i, 'j -> 1 1 1 j')
mask = F.pad(mask, (j - i, 0), value = False)
dots.masked_fill_(mask, mask_value)
del mask
@@ -731,6 +745,10 @@

out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)

if exists(r):
# https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
out = out * r + out

if head_scale:
out = out * self.head_scale_params

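For context, a hedged usage sketch: other attention flags in this repository are forwarded to `Attention` with an `attn_` prefix, so the new option would presumably be enabled as `attn_tensor_product = True` (the kwarg routing is assumed here, not shown in this diff):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_tensor_product = True   # assumed routing of the new flag to Attention
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)   # (1, 1024, 20000)
```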
