Skip to content

Commit

Permalink
suppress generating non-timestamp tokens at the beginning (openai#532)
Browse files Browse the repository at this point in the history
  • Loading branch information
jumon authored Nov 15, 2022
1 parent 9f70a35 commit 76148a5
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions whisper/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,10 +423,14 @@ def apply(self, logits: Tensor, tokens: Tensor):
else: # cannot be normal text tokens
logits[k, : self.tokenizer.eot] = -np.inf

# apply the `max_initial_timestamp` option
if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None:
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
logits[:, last_allowed + 1 :] = -np.inf
if tokens.shape[1] == self.sample_begin:
# suppress generating non-timestamp tokens at the beginning
logits[:, : self.tokenizer.timestamp_begin] = -np.inf

# apply the `max_initial_timestamp` option
if self.max_initial_timestamp_index is not None:
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
logits[:, last_allowed + 1 :] = -np.inf

# if sum of probability over timestamps is above any other token, sample timestamp
logprobs = F.log_softmax(logits.float(), dim=-1)
Expand Down

0 comments on commit 76148a5

Please sign in to comment.