Skip to content

Commit

Permalink
token_type_id_rev (fastnlp#329)
Browse files Browse the repository at this point in the history
  • Loading branch information
stratoes authored Oct 19, 2020
1 parent 08cf1a5 commit 3270b8b
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions fastNLP/embeddings/bert_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,7 @@ def forward(self, word_pieces, token_type_ids=None):
sep_mask = word_pieces.eq(self._sep_index) # batch_size x max_len
sep_mask_cumsum = sep_mask.long().flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
token_type_ids = sep_mask_cumsum.fmod(2)
if token_type_ids[0, 0].item(): # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0
token_type_ids = token_type_ids.eq(0).long()
token_type_ids = token_type_ids[:, :1].__xor__(token_type_ids) # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0

word_pieces = self.drop_word(word_pieces)
outputs = self.model(word_pieces, token_type_ids)
Expand Down Expand Up @@ -465,8 +464,7 @@ def forward(self, words):
sep_mask = word_pieces.eq(self._sep_index).long() # batch_size x max_len
sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
token_type_ids = sep_mask_cumsum.fmod(2)
if token_type_ids[0, 0].item(): # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0
token_type_ids = token_type_ids.eq(0).long()
token_type_ids = token_type_ids[:, :1].__xor__(token_type_ids) # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0
else:
token_type_ids = torch.zeros_like(word_pieces)
# 2. 获取hidden的结果,根据word_pieces进行对应的pool计算
Expand Down

0 comments on commit 3270b8b

Please sign in to comment.