Skip to content

Commit

Permalink
BLEU score refactoring.
Browse files Browse the repository at this point in the history
math.fsum is used to make float sum stable.
  • Loading branch information
dimazest committed Jul 17, 2014
1 parent 9ed9345 commit b95e5a7
Showing 1 changed file with 13 additions and 24 deletions.
37 changes: 13 additions & 24 deletions nltk/align/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class BLEU(object):
Part 2 - brevity penalty
As the modified n-gram precision stil has the problem from the short
As the modified n-gram precision still has the problem from the short
length sentence, brevity penalty is used to modify the overall BLEU
score according to length.
Expand Down Expand Up @@ -89,34 +89,24 @@ class BLEU(object):

@staticmethod
def compute(candidate, references, weights):
candidate = [c.lower() for c in candidate]
references = [[r.lower() for r in reference] for reference in references]

candidate = list(map(lambda x: x.lower(), candidate))
references = list(map(lambda x: [c.lower() for c in x], references))

n = len(weights)
p_ns = (BLEU.modified_precision(candidate, references, i) for i, _ in enumerate(weights, start=1))
s = math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns))

bp = BLEU.brevity_penalty(candidate, references)

s = 0.0
i = 1
for weight in weights:
p_n = BLEU.modified_precision(candidate, references, i)
if p_n != 0:
s += weight * math.log(p_n)
i += 1

return bp * math.exp(s)

@staticmethod
def modified_precision(candidate, references, n):

candidate_ngrams = list(ngrams(candidate, n))
c_words = set(candidate_ngrams)

if len(candidate_ngrams) == 0:
if not c_words:
return 0

c_words = set(candidate_ngrams)

for word in c_words:
count_w = candidate_ngrams.count(word) + 1

Expand All @@ -125,23 +115,22 @@ def modified_precision(candidate, references, n):
reference_ngrams = list(ngrams(reference, n))

count = reference_ngrams.count(word) + 1
if count > count_max:
count_max = count
count_max = max(count, count_max)

# TODO: count_w == candidate_ngrams.count(c_words[-1]) + 1
# (even though c_words is a set, so there is no last element, it's the last element returned by the iterator.)
# Is it the desired behavior?
return min(count_w, count_max) / (len(candidate) + len(c_words))

@staticmethod
def brevity_penalty(candidate, references):
c = len(candidate)

lengthes_ref = map(lambda x: abs(len(x) - c), references)

r = min(*lengthes_ref)
r = min(abs(len(r) - c) for r in references)

if c > r:
return 1
else:
return math.exp(1 - r/c)
return math.exp(1 - r / c)

# run doctests
if __name__ == "__main__":
Expand Down

0 comments on commit b95e5a7

Please sign in to comment.