[Generate] Remove outdated code (huggingface#11331)
* remove update function

* update

* refactor more

* refactor
patrickvonplaten authored Apr 20, 2021
1 parent bfd83c1 commit f464f10
Showing 1 changed file with 20 additions and 60 deletions.
src/transformers/generation_utils.py
@@ -483,31 +483,6 @@ def _expand_inputs_for_generation(
             model_kwargs["encoder_outputs"] = encoder_outputs
         return input_ids, model_kwargs
 
-    @staticmethod
-    def _init_sequence_length_for_generation(
-        input_ids: torch.LongTensor, max_length: int
-    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
-        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
-        sequence_lengths = input_ids.new(input_ids.shape[0]).fill_(max_length)
-
-        cur_len = input_ids.shape[-1]
-        return sequence_lengths, unfinished_sequences, cur_len
-
-    @staticmethod
-    def _update_seq_length_for_generation(
-        sequence_lengths: torch.LongTensor,
-        unfinished_sequences: torch.LongTensor,
-        cur_len: int,
-        is_eos_in_next_token: torch.BoolTensor,
-    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
-        # check if sentence is not finished yet
-        is_sent_unfinished = unfinished_sequences.mul(is_eos_in_next_token.long()).bool()
-
-        # update sentence length
-        sequence_lengths = sequence_lengths.masked_fill(is_sent_unfinished, cur_len)
-        unfinished_sequences = unfinished_sequences.mul((~is_eos_in_next_token).long())
-        return sequence_lengths, unfinished_sequences
-
     @staticmethod
     def _update_model_kwargs_for_generation(
         outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
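
For reference, the two deleted helpers only did tensor bookkeeping, and the hunks below keep just the unfinished_sequences mask while dropping sequence_lengths entirely. A minimal standalone sketch of what the helpers computed (toy values, plain torch, not the transformers API):

    import torch

    # Standalone toy reproduction of the deleted helpers (values made up).
    input_ids = torch.tensor([[5, 6], [7, 8], [9, 2]])  # batch of 3 prompts
    max_length = 10

    # what _init_sequence_length_for_generation computed:
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)       # tensor([1, 1, 1])
    sequence_lengths = input_ids.new(input_ids.shape[0]).fill_(max_length)  # tensor([10, 10, 10])
    cur_len = input_ids.shape[-1]                                           # 2

    # what _update_seq_length_for_generation computed, for a step in which
    # only the third sequence produced the EOS token:
    is_eos_in_next_token = torch.tensor([False, False, True])
    is_sent_unfinished = unfinished_sequences.mul(is_eos_in_next_token.long()).bool()
    sequence_lengths = sequence_lengths.masked_fill(is_sent_unfinished, cur_len)     # tensor([10, 10, 2])
    unfinished_sequences = unfinished_sequences.mul((~is_eos_in_next_token).long())  # tensor([1, 1, 0])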
@@ -1271,10 +1246,9 @@ def greedy_search(
                 model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
             )
 
-        # init sequence length tensors
-        sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
-            input_ids, max_length
-        )
+        # keep track of which sequences are already finished
+        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+        cur_len = input_ids.shape[-1]
 
         this_peer_finished = False  # used by synced_gpus only
         while cur_len < max_length:
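
The inlined replacement is two lines of initialization; input_ids.new(...) allocates the mask with the same dtype and device as input_ids. A quick illustration with assumed toy ids:

    import torch

    input_ids = torch.tensor([[101, 42, 7], [101, 13, 9]])  # assumed batch of 2 prompts
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    cur_len = input_ids.shape[-1]
    print(unfinished_sequences)        # tensor([1, 1]) -- 1 means "still generating"
    print(unfinished_sequences.dtype)  # torch.int64, inherited from input_ids
    print(cur_len)                     # 3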
@@ -1330,29 +1304,23 @@ def greedy_search(
             # argmax
             next_tokens = torch.argmax(next_tokens_scores, dim=-1)
 
-            # add code that transforms next_tokens to tokens_to_add
+            # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
 
-            # add token and increase length by one
+            # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-
-            # update sequence length
-            if eos_token_id is not None:
-                sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
-                    sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
-                )
-
-            # update model kwargs
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
             )
-
-            # increase cur_len
             cur_len = cur_len + 1
 
-            # stop when there is a </s> in each sentence, or if we exceed the maximum length
+            # if eos_token was found in one sentence, set sentence to finished
+            if eos_token_id is not None:
+                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
+
+            # stop when each sentence is finished, or if we exceed the maximum length
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                 if not synced_gpus:
                     break
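
The inlined update is plain integer arithmetic on the mask: rows that already finished get pad_token_id instead of their argmax token, and rows that just produced EOS are zeroed in the mask for subsequent steps. A worked toy step (hypothetical token ids):

    import torch

    pad_token_id, eos_token_id = 0, 2
    unfinished_sequences = torch.tensor([1, 0])  # sequence 1 already finished
    next_tokens = torch.tensor([2, 57])          # raw argmax output for this step

    # finished sentences get a padding token instead of their generated token
    next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
    print(next_tokens)  # tensor([2, 0])

    # sequences that just emitted EOS are marked finished for later steps
    unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
    print(unfinished_sequences)  # tensor([0, 0]) -- loop can stop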
@@ -1511,10 +1479,9 @@ def sample(
                 model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
             )
 
-        # init sequence length tensors
-        sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
-            input_ids, max_length
-        )
+        # keep track of which sequences are already finished
+        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+        cur_len = input_ids.shape[-1]
 
         this_peer_finished = False  # used by synced_gpus only
         # auto-regressive generation
@@ -1571,32 +1538,25 @@ def sample(
 
             # sample
             probs = F.softmax(next_token_scores, dim=-1)
-
             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
 
-            # add code that transforms next_tokens to tokens_to_add
+            # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
 
-            # add token and increase length by one
+            # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-
-            # update sequence length
-            if eos_token_id is not None:
-                sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
-                    sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
-                )
-
-            # update model kwargs
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
             )
-
-            # increase cur_len
             cur_len = cur_len + 1
 
-            # stop when there is a </s> in each sentence, or if we exceed the maximum length
+            # if eos_token was found in one sentence, set sentence to finished
+            if eos_token_id is not None:
+                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
+
+            # stop when each sentence is finished, or if we exceed the maximum length
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                 if not synced_gpus:
                     break
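
Taken together, the simplified bookkeeping is enough to drive a whole decoding loop. A self-contained toy sketch, with random logits standing in for a real forward pass (illustrative only, not transformers code):

    import torch

    torch.manual_seed(0)
    vocab_size, pad_token_id, eos_token_id, max_length = 12, 0, 2, 8
    input_ids = torch.randint(3, vocab_size, (2, 3))  # batch of 2 made-up prompts

    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    cur_len = input_ids.shape[-1]

    while cur_len < max_length:
        logits = torch.randn(input_ids.shape[0], vocab_size)  # dummy "model" output
        next_tokens = torch.argmax(logits, dim=-1)
        # finished rows keep emitting pad_token_id
        next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        cur_len += 1
        # rows that just emitted EOS are marked finished
        unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
        if unfinished_sequences.max() == 0:
            break

    print(input_ids)  # rows are padded with pad_token_id after their EOS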