Skip to content

Commit

Permalink
removed alm hypothesis
Browse files Browse the repository at this point in the history
Signed-off-by: BAAI-OpenPlatform <[email protected]>
  • Loading branch information
BAAI-OpenPlatform committed Dec 19, 2022
1 parent 9387413 commit 1eb0382
Showing 1 changed file with 1 addition and 49 deletions.
50 changes: 1 addition & 49 deletions flagai/model/predictor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,54 +299,6 @@ def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
ret = self.worst_score >= cur_score
return ret


class BeamHypothesesALM:
def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool):
"""
Initialize n-best list of hypotheses.
"""
self.max_length = max_length - 1 # ignoring bos_token
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.num_beams = num_beams
self.beams = []
self.worst_score = 1e9

def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.beams)

def add(self, hyp: torch.LongTensor, sum_logprobs: float, mems=None):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / (max(hyp.shape[-1], 1) ** self.length_penalty)
if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp, mems))
if len(self) > self.num_beams:
sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
del self.beams[sorted_next_scores[0][1]]
self.worst_score = sorted_next_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)

def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
"""
If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
one in the heap, then we are done with this sentence.
"""

if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
ret = self.worst_score >= cur_score
return ret


def viterbi_decode(nodes, trans):
"""
Expand Down Expand Up @@ -566,7 +518,7 @@ def __init__(

self._is_init = False
self._beam_hyps = [
BeamHypothesesALM(
BeamHypotheses(
num_beams=self.num_beams,
max_length=self.max_length,
length_penalty=self.length_penalty,
Expand Down

0 comments on commit 1eb0382

Please sign in to comment.