remove cascading heads
lucidrains committed Oct 3, 2023
1 parent 37b351b commit 4f9775b
Showing 4 changed files with 2 additions and 94 deletions.
README.md (10 changes: 0 additions & 10 deletions)
@@ -1932,16 +1932,6 @@ generated = model.generate(start_emb, 17) # (17, 777)
}
```

-```bibtex
-@article{Liu2023EfficientViTME,
-title = {EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention},
-author = {Xinyu Liu and Houwen Peng and Ningxin Zheng and Yuqing Yang and Han Hu and Yixuan Yuan},
-journal = {ArXiv},
-year = {2023},
-volume = {abs/2305.07027}
-}
-```
-
```bibtex
@article{Kazemnejad2023TheIO,
title = {The Impact of Positional Encoding on Length Generalization in Transformers},
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
-version = '1.23.0',
+version = '1.23.1',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
x_transformers/attend.py (76 changes: 0 additions & 76 deletions)
@@ -346,79 +346,3 @@ def forward(
)

return out, intermediates

-# cascading heads logic
-
-def to_single_heads(t, dim = 1):
-    heads = t.unbind(dim = dim)
-    return tuple(head.unsqueeze(dim) for head in heads)
-
-class CascadingHeads(nn.Module):
-    def __init__(self, attend: Attend):
-        super().__init__()
-        self.attend = attend
-
-    def forward(
-        self,
-        q, k, v,
-        mask = None,
-        attn_bias = None,
-        prev_attn = None
-    ):
-        assert q.shape[-1] == v.shape[-1], 'cascading heads can only be done if query / key and value head dimensions are the same'
-
-        # split inputs into per-head inputs
-
-        heads = q.shape[1]
-
-        queries = to_single_heads(q)
-        keys = to_single_heads(k) if k.ndim == 4 else ((k,) * heads)
-        values = to_single_heads(v) if v.ndim == 4 else ((v,) * heads)
-
-        mask = (mask,) * heads
-
-        attn_bias = to_single_heads(attn_bias, dim = 0) if exists(attn_bias) else ((None,) * heads)
-        prev_attn = to_single_heads(prev_attn) if exists(prev_attn) else ((None,) * heads)
-
-        # now loop through each head, with the output of the previous head summed into the next head's query
-        # thus cascading
-
-        all_outs = []
-        all_intermediates = []
-
-        prev_head_out = None
-
-        for h_q, h_k, h_v, h_mask, h_attn_bias, h_prev_attn in zip(queries, keys, values, mask, attn_bias, prev_attn):
-
-            if exists(prev_head_out):
-                h_q = h_q + prev_head_out
-
-            out, intermediates = self.attend(
-                h_q, h_k, h_v,
-                mask = h_mask,
-                attn_bias = h_attn_bias,
-                prev_attn = h_prev_attn
-            )
-
-            prev_head_out = out
-
-            all_outs.append(out)
-            all_intermediates.append(intermediates)
-
-        # cat all output heads
-
-        all_outs = torch.cat(all_outs, dim = 1)
-
-        # cat all intermediates, if they exist
-
-        qk_similarities, pre_softmax_attn, post_softmax_attn = zip(*map(lambda i: i.to_tuple(), all_intermediates))
-
-        qk_similarities, pre_softmax_attn, post_softmax_attn = map(compact, (qk_similarities, pre_softmax_attn, post_softmax_attn))
-
-        aggregated_intermediates = Intermediates(
-            qk_similarities = torch.cat(qk_similarities, dim = 1) if len(qk_similarities) > 0 else None,
-            pre_softmax_attn = torch.cat(pre_softmax_attn, dim = 1) if len(pre_softmax_attn) > 0 else None,
-            post_softmax_attn = torch.cat(post_softmax_attn, dim = 1) if len(post_softmax_attn) > 0 else None
-        )
-
-        return all_outs, aggregated_intermediates
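
For reference, the deleted wrapper ran attention one head at a time, adding each head's output into the next head's query before attending, then concatenated the per-head results. Below is a minimal standalone sketch of that cascading idea, using PyTorch's built-in scaled dot product attention in place of the library's Attend module; it illustrates the pattern, not the removed implementation.

```python
import torch
import torch.nn.functional as F

def cascading_attention(q, k, v):
    # q, k, v: (batch, heads, seq, dim_head); query and value head dims must match,
    # since each head's output is added to the next head's query
    outs = []
    prev_out = None
    for h in range(q.shape[1]):
        hq = q[:, h:h + 1]
        if prev_out is not None:
            hq = hq + prev_out  # cascade: feed the previous head's output into this head's query
        out = F.scaled_dot_product_attention(hq, k[:, h:h + 1], v[:, h:h + 1])
        outs.append(out)
        prev_out = out
    return torch.cat(outs, dim = 1)  # reassemble along the head dimension

q = k = v = torch.randn(2, 8, 128, 64)
out = cascading_attention(q, k, v)   # (2, 8, 128, 64)
```

Note the serial loop: cascading introduces a sequential dependency between heads, so the usual fully parallel multi-head computation no longer applies.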
x_transformers/x_transformers.py (8 changes: 1 addition & 7 deletions)
@@ -14,7 +14,7 @@
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange

-from x_transformers.attend import Attend, Intermediates, CascadingHeads
+from x_transformers.attend import Attend, Intermediates
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

# constants
@@ -662,7 +662,6 @@ def __init__(
shared_kv = False,
value_dim_head = None,
tensor_product = False, # https://arxiv.org/abs/2208.06061
-cascading_heads = False,
add_zero_kv = False, # same as add_zero_attn in pytorch
rotary_embed_values = False,
onnxable = False
@@ -674,7 +673,6 @@ def __init__(
self.causal = causal
self.max_attend_past = max_attend_past

-
assert not (exists(kv_heads) and one_kv_head), 'either attn_one_kv_head is set to True (in which case kv_heads is set to 1), or attn_kv_heads is set, but not both'

value_dim_head = default(value_dim_head, dim_head)
@@ -738,10 +736,6 @@ def __init__(
onnxable = onnxable
)

-if cascading_heads:
-    # cascading heads - wrap the Attend logic
-    self.attend = CascadingHeads(self.attend)
-
# head scaling
self.head_scale = head_scale
if head_scale:
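
For anyone pinned to a release that predates this commit (1.23.0 or earlier, e.g. pip install x-transformers==1.23.0), the deleted lines above were the whole integration: setting the now-removed cascading_heads flag simply wrapped the module's Attend instance in CascadingHeads. A hedged usage sketch against those older versions follows; the causal argument to Attend is assumed purely for illustration.

```python
import torch
# CascadingHeads is only importable from x-transformers releases up to 1.23.0; this commit removes it
from x_transformers.attend import Attend, CascadingHeads

attend = CascadingHeads(Attend(causal = True))  # wrap Attend, as the removed code did when cascading_heads = True

q = torch.randn(1, 8, 256, 64)   # (batch, heads, seq, dim_head)
k = torch.randn(1, 8, 256, 64)
v = torch.randn(1, 8, 256, 64)

out, intermediates = attend(q, k, v)  # out: (1, 8, 256, 64), heads processed one after another
```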