add layerscale, researcher friend says it works well, also bump torch and einops dep
lucidrains committed May 11, 2024
1 parent cfbadea commit a11e039
Showing 5 changed files with 57 additions and 13 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -2119,6 +2119,17 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
}
```

```bibtex
@article{Wang2022DeepNetST,
title = {DeepNet: Scaling Transformers to 1,000 Layers},
author = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei},
journal = {ArXiv},
year = {2022},
volume = {abs/2203.00555},
url = {https://api.semanticscholar.org/CorpusID:247187905}
}
```

```bibtex
@article{Rafailov2023DirectPO,
title = {Direct Preference Optimization: Your Language Model is Secretly a Reward Model},
6 changes: 3 additions & 3 deletions setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '1.29.2',
version = '1.30.0',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
@@ -16,8 +16,8 @@
'transformers'
],
install_requires=[
'torch>=1.6',
'einops>=0.7.0'
'torch>=2.0',
'einops>=0.8.0'
],
classifiers=[
'Development Status :: 4 - Beta',
4 changes: 2 additions & 2 deletions x_transformers/continuous.py
@@ -143,11 +143,11 @@ def forward(

if return_mems:
hiddens = intermediates.hiddens
new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
new_mems = tuple(t[..., -self.max_mem_len:, :].detach() for t in hiddens)
return out, new_mems

if return_attn:
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
attn_maps = tuple(t.post_softmax_attn for t in intermediates.attn_intermediates)
return out, attn_maps

return out
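For context, the touched return paths are exercised as below — a minimal sketch, assuming this `forward` belongs to `ContinuousTransformerWrapper` and the constructor matches the usual continuous example from the README; the only behavioral change is that `new_mems` (and likewise `attn_maps`) now come back as tuples instead of lists, with identical contents.

```python
import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

# assumed constructor arguments, mirroring the README's continuous example
model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 1024,
    max_mem_len = 256,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

x = torch.randn(1, 128, 32)

out, new_mems = model(x, return_mems = True)
assert isinstance(new_mems, tuple)  # previously a list; each entry is a detached per-layer memory
```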
45 changes: 39 additions & 6 deletions x_transformers/x_transformers.py
@@ -142,7 +142,7 @@ def init_zero_(layer):
# keyword argument helpers

def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
values = tuple(d.pop(key) for key in keys)
return dict(zip(keys, values))

def group_dict_by_key(cond, d):
@@ -151,7 +151,7 @@ def group_dict_by_key(cond, d):
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
return tuple(return_val)

def string_begins_with(prefix, str):
return str.startswith(prefix)
@@ -161,7 +161,8 @@ def group_by_key_prefix(prefix, d):

def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
prefix_len = len(prefix)
kwargs_without_prefix = {key[prefix_len:]: value for key, value in kwargs_with_prefix.items()}
return kwargs_without_prefix, kwargs

# structured dropout, more effective than traditional attention dropouts
@@ -457,7 +458,6 @@ def forward(self, t):

return freqs, scale


def rotate_half(x):
x = rearrange(x, '... (j d) -> ... j d', j = 2)
x1, x2 = x.unbind(dim = -2)
@@ -574,8 +574,8 @@ def forward(self, x, residual):
def shift(t, amount, mask = None):
if amount == 0:
return t
else:
amount = min(amount, t.shape[1])

amount = min(amount, t.shape[1])

if exists(mask):
t = t.masked_fill(~mask[..., None], 0.)
@@ -599,6 +599,23 @@ def forward(self, x, **kwargs):
x = torch.cat((*segments_to_shift, *rest), dim = -1)
return self.fn(x, **kwargs)

# post branch operator

class LayerScale(Module):
def __init__(self, fn: Module, dim, init_value = 0.):
super().__init__()
self.fn = fn
self.gamma = nn.Parameter(torch.ones(dim) * init_value)

def forward(self, x, **kwargs):
out = self.fn(x, **kwargs)

if isinstance(out, Tensor):
return out * self.gamma

out, *rest = out
return out * self.gamma, *rest

# feedforward

class GLU(Module):
@@ -1047,6 +1064,8 @@ def __init__(
layer_dropout = 0.,
cross_attn_tokens_dropout = 0.,
disable_abs_pos_emb = None,
use_layerscale = False,
layerscale_init_value = 0.,
**kwargs
):
super().__init__()
@@ -1110,6 +1129,8 @@ def __init__(

self.cross_attend = cross_attend

# determine norm

assert (int(use_scalenorm) + int(use_rmsnorm) + int(use_simple_rmsnorm)) <= 1, 'you can only use either scalenorm, rmsnorm, or simple rmsnorm'

if use_scalenorm:
@@ -1123,6 +1144,8 @@

norm_fn = partial(norm_class, dim)

# determine default block layer type order

if cross_attend and not only_cross:
default_block = ('a', 'c', 'f')
elif cross_attend and only_cross:
@@ -1133,6 +1156,13 @@
if macaron:
default_block = ('f',) + default_block

# determine post branch wrapper

post_branch_fn = None

if use_layerscale:
post_branch_fn = partial(LayerScale, dim = dim, init_value = layerscale_init_value)

# zero init

if zero_init_branch_output:
@@ -1221,6 +1251,9 @@ def __init__(
shift_range_lower = -layer_shift_tokens if not causal else 0
layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)

if exists(post_branch_fn):
layer = post_branch_fn(layer)

residual_fn = GRUGating if gate_residual else Residual
residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)

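To try the new option end to end — a minimal sketch, assuming `use_layerscale` and `layerscale_init_value` are forwarded through `Decoder` to `AttentionLayers` like the other layer keyword arguments:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8,
        use_layerscale = True,         # wrap each attention / feedforward branch in LayerScale
        layerscale_init_value = 0.     # gamma starts at zero, so branches are initially muted
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)  # (1, 1024, 20000)
```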
4 changes: 2 additions & 2 deletions x_transformers/xval.py
@@ -176,11 +176,11 @@ def forward(

if return_mems:
hiddens = intermediates.hiddens
new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
new_mems = tuple(t[..., -self.max_mem_len:, :].detach() for t in hiddens)
return out, new_mems

if return_attn:
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
attn_maps = tuple(t.post_softmax_attn for t in intermediates.attn_intermediates)
return out, attn_maps

return out
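For reference, the behavior of the `LayerScale` wrapper introduced above, in isolation: with the default `init_value = 0.` the per-channel `gamma` starts at zero, so a wrapped branch contributes nothing at initialization and only enters the residual stream as training moves `gamma` away from zero. A self-contained sketch (the class body mirrors the diff; the wrapped `nn.Linear` is just a stand-in for an attention or feedforward branch):

```python
import torch
from torch import nn, Tensor
from torch.nn import Module

class LayerScale(Module):
    def __init__(self, fn: Module, dim, init_value = 0.):
        super().__init__()
        self.fn = fn
        self.gamma = nn.Parameter(torch.ones(dim) * init_value)

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)

        if isinstance(out, Tensor):
            return out * self.gamma

        # branches returning (out, *intermediates) only have the main output scaled
        out, *rest = out
        return out * self.gamma, *rest

branch = nn.Linear(512, 512)            # stand-in branch
wrapped = LayerScale(branch, dim = 512) # gamma initialized to zeros

x = torch.randn(2, 16, 512)
assert torch.allclose(wrapped(x), torch.zeros_like(x))  # zero-init gamma mutes the branch
```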
