Skip to content

Commit

Permalink
Fix ALiBi bias caching for varying input shapes
Browse files (browse the repository at this point in the history)
  • Loading branch information
ilya16 committed May 18, 2022
1 parent df89141 commit 74f241d
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions x_transformers/x_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def __init__(self, heads, **kwargs):
def get_bias(self, i, j, device):
i_arange = torch.arange(i, device = device)
j_arange = torch.arange(j, device = device)
bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 1 j') - rearrange(i_arange, 'i -> 1 1 i 1'))
bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
return bias

@staticmethod
Expand All @@ -270,15 +270,15 @@ def forward(self, qk_dots):
h, i, j, device = *qk_dots.shape[-3:], qk_dots.device

if exists(self.bias) and self.bias.shape[-1] >= j:
return qk_dots + self.bias[..., :j]

if not exists(self.bias):
bias = self.get_bias(i, j, device)
bias = bias * self.slopes
return qk_dots + self.bias[..., :i, :j]

bias = self.get_bias(i, j, device)
bias = bias * self.slopes

num_heads_unalibied = h - bias.shape[0]
bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
self.register_buffer('bias', bias, persistent=False)

num_heads_unalibied = h - bias.shape[-3]
bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
self.register_buffer('bias', bias, persistent = False)
return qk_dots + self.bias

class LearnedAlibiPositionalBias(AlibiPositionalBias):
Expand All @@ -293,13 +293,12 @@ def forward(self, qk_dots):
def get_slopes(param):
return F.pad(param.exp(), (0, 0, 0, 0, 0, h - param.shape[0]))

if not exists(self.bias):
bias = self.get_bias(i, j, device)
self.register_buffer('bias', bias, persistent = False)
if self.bias.shape[-1] >= j:
if exists(self.bias) and self.bias.shape[-1] >= j:
bias = self.bias[..., :i, :j]
else:
bias = self.bias
bias = self.get_bias(i, j, device)
self.register_buffer('bias', bias, persistent=False)

slopes = get_slopes(self.learned_logslopes)
bias = bias * slopes

Expand Down

0 comments on commit 74f241d

Please sign in to comment.