Skip to content

Commit

Permalink
[GPTNeo] create local attention mask ones (huggingface#11335)
Browse files Browse the repository at this point in the history
* create local attention mask ones

* remove old method, address patricks comment
  • Loading branch information
patil-suraj authored Apr 20, 2021
1 parent f464f10 commit cfd2eaa
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 65 deletions.
128 changes: 69 additions & 59 deletions src/transformers/models/gpt_neo/modeling_gpt_neo.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,57 @@ def _look_back(tensor, block_length, window_size, pad_value=0, is_key_value=True
padded_tensor = padded_tensor.transpose(-2, -1)
return padded_tensor

@staticmethod
def _split_seq_length_dim_to(tensors, dim_factor_1, dim_factor_2):
"""
Splits sequence length dim of tensors into `dim_factor_1` and `dim_factor_2` dims
"""
batch_size = tensors.shape[0]
split_dim_shape = (batch_size, dim_factor_1, dim_factor_2)

if len(tensors.shape) == 3:
return torch.reshape(tensors, split_dim_shape + (-1,))
elif len(tensors.shape) == 2:
return torch.reshape(tensors, split_dim_shape)
else:
raise ValueError(f"Input vector rank should be one of [2, 3], but is: {len(tensors.shape)}")

@staticmethod
def create_local_attention_mask(batch_size, seq_length, window_size, device, attention_mask=None):
block_length, num_blocks = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size)
indices = torch.arange(seq_length, dtype=torch.long, device=device).repeat(batch_size, 1)

query_indices = GPTNeoAttentionMixin._split_seq_length_dim_to(indices, num_blocks, block_length)
key_indices = GPTNeoAttentionMixin._look_back(indices, block_length, window_size, is_key_value=False)

# create mask tensor such that each block contains a causal_mask for that block
causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2))

if attention_mask is None:
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=device)

# A block can also be padded becuase of the _look_back operation
# look back into the attention_block such that it will also get padded the same way
# and have 0s in the padded position
attention_mask = GPTNeoAttentionMixin._look_back(attention_mask, block_length, window_size, is_key_value=False)
attention_mask = attention_mask.unsqueeze(-2) # Add an extra dimension to account for hidden_dim

# Multiply the causal_mask with attention_mask so the padded positions (by _look_back operation)
# will contain 0s.
# This also makes sure that other positions ignored by the attention_mask will also be ignored
# in the causal_mask.
causal_mask = causal_mask * attention_mask

# In GPT Neo's local attention each window can attend to at most window_size tokens
# rest of the tokens should be ignored.
relative_position = key_indices.unsqueeze(-2) - query_indices.unsqueeze(-1)
visible = torch.gt(relative_position, -window_size)

causal_mask = causal_mask * visible
causal_mask = causal_mask.unsqueeze(-3).bool() # Add an extra dimension to account for num_heads

return causal_mask

def _split_heads(self, tensor, num_heads, attn_head_size):
"""
Splits hidden_size dim into attn_head_size and num_heads
Expand All @@ -218,20 +269,6 @@ def _merge_heads(self, tensor, num_heads, attn_head_size):
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
return tensor.view(new_shape)

def _split_seq_length_dim_to(self, tensors, dim_factor_1, dim_factor_2, hidden_size):
"""
Splits sequence length dim of tensors into `dim_factor_1` and `dim_factor_2` dims
"""
batch_size = tensors.shape[0]
split_dim_shape = (batch_size, dim_factor_1, dim_factor_2)

if len(tensors.shape) == 3:
return torch.reshape(tensors, split_dim_shape + (hidden_size,))
elif len(tensors.shape) == 2:
return torch.reshape(tensors, split_dim_shape)
else:
raise ValueError(f"Input vector rank should be one of [2, 3], but is: {len(tensors.shape)}")

def _attn(self, query, key, value, causal_mask, masked_bias, attn_dropout, attention_mask=None, head_mask=None):
# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32)
Expand Down Expand Up @@ -289,8 +326,8 @@ def __init__(self, config):
def forward(
self,
hidden_states,
layer_past=None,
attention_mask=None,
layer_past=None,
head_mask=None,
use_cache=False,
output_attentions=False,
Expand Down Expand Up @@ -357,45 +394,11 @@ def __init__(self, config):

self.window_size = config.window_size

def _create_attention_mask(self, batch_size, seq_length, num_blocks, block_length, device, attention_mask=None):
indices = torch.arange(seq_length, dtype=torch.long, device=device).repeat(batch_size, 1)

query_indices = self._split_seq_length_dim_to(indices, num_blocks, block_length, self.embed_dim)
key_indices = self._look_back(indices, block_length, self.window_size, is_key_value=False)

# create mask tensor such that each block contains a causal_mask for that block
causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2))

if attention_mask is None:
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=device)

# A block can also be padded becuase of the _look_back operation
# look back into the attention_block such that it will also get padded the same way
# and have 0s in the padded position
attention_mask = self._look_back(attention_mask, block_length, self.window_size, is_key_value=False)
attention_mask = attention_mask.unsqueeze(-2) # Add an extra dimension to account for hidden_dim

# Multiply the causal_mask with attention_mask so the padded positions (by _look_back operation)
# will contain 0s.
# This also makes sure that other positions ignored by the attention_mask will also be ignored
# in the causal_mask.
causal_mask = causal_mask * attention_mask

# In GPT Neo's local attention each window can attend to at most window_size tokens
# rest of the tokens should be ignored.
relative_position = key_indices.unsqueeze(-2) - query_indices.unsqueeze(-1)
visible = torch.gt(relative_position, -self.window_size)

causal_mask = causal_mask * visible
causal_mask = causal_mask.unsqueeze(-3).bool() # Add an extra dimension to account for num_heads

return causal_mask

def forward(
self,
hidden_states,
attention_mask,
layer_past=None,
attention_mask=None,
head_mask=None,
use_cache=False,
output_attentions=False,
Expand All @@ -421,9 +424,9 @@ def forward(
# create buckets
if layer_past is not None:
# we just need 1 block with block_length 1 when caching is enabled
query = self._split_seq_length_dim_to(query, 1, 1, self.embed_dim)
query = self._split_seq_length_dim_to(query, 1, 1)
else:
query = self._split_seq_length_dim_to(query, num_blocks, block_length, self.embed_dim)
query = self._split_seq_length_dim_to(query, num_blocks, block_length)

key = self._look_back(key, block_length, self.window_size)
value = self._look_back(value, block_length, self.window_size)
Expand All @@ -437,18 +440,16 @@ def forward(
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)

mask = self._create_attention_mask(
batch_size, full_seq_length, num_blocks, block_length, hidden_states.device, attention_mask
)
if layer_past is not None:
mask = mask[:, -1:, :, -1:, :] # only take the mask for the last block
# only take the mask for the last block
attention_mask = attention_mask[:, -1:, :, -1:, :]

# attn
attn_output, attn_weights = self._attn(
query,
key,
value,
causal_mask=mask,
causal_mask=attention_mask,
masked_bias=self.masked_bias,
attn_dropout=self.attn_dropout,
head_mask=head_mask,
Expand Down Expand Up @@ -495,8 +496,8 @@ def forward(
):
outputs = self.attention(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
layer_past=layer_past,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
Expand Down Expand Up @@ -767,6 +768,8 @@ def forward(
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)

device = input_ids.device if input_ids is not None else inputs_embeds.device
if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
Expand All @@ -792,6 +795,13 @@ def forward(
else:
global_attention_mask = None

# Local causal attention mask
batch_size, seq_length = input_shape
full_seq_length = seq_length + past_length
local_attention_mask = GPTNeoAttentionMixin.create_local_attention_mask(
batch_size, full_seq_length, self.config.window_size, device, attention_mask
)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x num_headss x N x N
Expand All @@ -816,7 +826,7 @@ def forward(
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
attn_type = self.config.attention_layers[i]
attn_mask = global_attention_mask if attn_type == "global" else attention_mask
attn_mask = global_attention_mask if attn_type == "global" else local_attention_mask

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
Expand Down
20 changes: 14 additions & 6 deletions tests/test_modeling_gpt_neo.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
GPTNeoForCausalLM,
GPTNeoModel,
)
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin, GPTNeoLocalSelfAttention
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin


class GPTNeoModelTester:
Expand Down Expand Up @@ -497,12 +497,14 @@ def test_look_back(self):

def test_create_attention_mask(self):
config = GPTNeoConfig.from_pretrained("valhalla/gpt-neo-random-tiny")
layer = GPTNeoLocalSelfAttention(config)
window_size = config.window_size
batch_size, seq_length = 8, 1
block_length, num_blocks = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size)

causal_mask = layer._create_attention_mask(batch_size, seq_length, num_blocks, block_length, torch_device)
# causal_mask = layer._create_attention_mask(batch_size, seq_length, num_blocks, block_length, torch_device)
causal_mask = GPTNeoAttentionMixin.create_local_attention_mask(
batch_size, seq_length, config.window_size, torch_device
)
# check shapes
expected_shape = [batch_size, num_blocks, 1, block_length, window_size + block_length]
self.assertListEqual(list(causal_mask.shape), expected_shape)
Expand All @@ -516,8 +518,11 @@ def test_create_attention_mask(self):
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=torch_device)
attention_mask[:, -3:] = 0 # don't attend last 3 tokens

causal_mask = layer._create_attention_mask(
batch_size, seq_length, num_blocks, block_length, torch_device, attention_mask
# causal_mask = layer._create_attention_mask(
# batch_size, seq_length, num_blocks, block_length, torch_device, attention_mask
# )
causal_mask = GPTNeoAttentionMixin.create_local_attention_mask(
batch_size, seq_length, config.window_size, torch_device, attention_mask
)
# last 3 tokens will be in the last block and shoul have 0s in causal_mask
self.assertTrue(torch.all(causal_mask[:, -1, :, :, -3:] == 0))
Expand All @@ -539,8 +544,11 @@ def test_local_attn_probs(self):
mask_tokens = 3
attention_mask = torch.ones(batch_size, seq_length, device=torch_device, dtype=torch.long)
attention_mask[:, -mask_tokens:] = 0 # dont atten last mask_tokens
local_causal_mask = GPTNeoAttentionMixin.create_local_attention_mask(
batch_size, seq_length, model.config.window_size, torch_device, attention_mask
)

_, attn_probs = layer(hidden_states, attention_mask=attention_mask, output_attentions=True)
_, attn_probs = layer(hidden_states, attention_mask=local_causal_mask, output_attentions=True)

# the last 3 tokens will be in the last block, and should have 0 attn_probs
self.assertTrue(torch.all(attn_probs[:, -1, :, -mask_tokens:, -mask_tokens:] == 0))
Expand Down

0 comments on commit cfd2eaa

Please sign in to comment.