Skip to content

Commit

Permalink
fix a bug
Browse files Browse the repository at this point in the history
  • Loading branch information
920232796 committed Nov 17, 2020
1 parent 989de3f commit 7440a17
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 40 deletions.
Binary file modified .DS_Store
Binary file not shown.
9 changes: 3 additions & 6 deletions bert_seq2seq/model/bert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,9 @@ def __init__(self, config):
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(self, input_ids=None, token_type_ids=None, position_ids=None):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]


input_shape = input_ids.size()

seq_length = input_shape[1]
device = input_ids.device
if position_ids is None:
Expand Down Expand Up @@ -446,7 +444,6 @@ def forward(
output_attentions=False
):


extended_attention_mask = (input_ids > 0).float()
# 注意力矩阵mask: [batch_size, 1, 1, seq_length]
extended_attention_mask = extended_attention_mask.unsqueeze(1).unsqueeze(2)
Expand Down
34 changes: 0 additions & 34 deletions test/test.py

This file was deleted.

0 comments on commit 7440a17

Please sign in to comment.