Merge pull request #1229 from alvations/develop · ExplodingCabbage/nltk@1d7966a · GitHub


Commit 1d7966a

Merge pull request nltk#1229 from alvations/develop
Added function to calculate Corpus-level BLEU and RIBES
2 parents c14c15a + c7c8dfb commit 1d7966a

File tree

3 files changed: +245 -48 lines changed


nltk/translate/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@
 from nltk.translate.ibm3 import IBMModel3
 from nltk.translate.ibm4 import IBMModel4
 from nltk.translate.ibm5 import IBMModel5
-from nltk.translate.bleu_score import bleu
-from nltk.translate.ribes_score import ribes
+from nltk.translate.bleu_score import sentence_bleu as bleu
+from nltk.translate.ribes_score import sentence_ribes as ribes
 from nltk.translate.metrics import alignment_error_rate
 from nltk.translate.stack_decoder import StackDecoder
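
With this change the package-level names `bleu` and `ribes` become aliases for the sentence-level scorers, while the new corpus-level functions are imported from their own modules. A minimal usage sketch (the example sentences below are made up for illustration, not taken from the patch):

# Sketch of the re-exported names after this commit; `ribes` is aliased the same way.
from nltk.translate import bleu  # alias for sentence_bleu
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu

reference = 'the quick brown fox jumps over the lazy dog'.split()
hypothesis = 'the quick brown fox jumped over the lazy dog'.split()

# The package-level alias and the module-level function are the same callable.
assert bleu([reference], hypothesis) == sentence_bleu([reference], hypothesis)

# corpus_bleu() expects one list of references per hypothesis.
print(corpus_bleu([[reference]], [hypothesis]))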

nltk/translate/bleu_score.py

Lines changed: 188 additions & 44 deletions
@@ -3,29 +3,27 @@
 #
 # Copyright (C) 2001-2015 NLTK Project
 # Authors: Chin Yee Lee, Hengfeng Li, Ruxin Hou, Calvin Tanujaya Lim
-# Contributors: Dmitrijs Milajevs
+# Contributors: Dmitrijs Milajevs, Liling Tan
 # URL: <http://nltk.org/>
 # For license information, see LICENSE.TXT
 """BLEU score implementation."""
 
 from __future__ import division
 
 import math
+from fractions import Fraction
+from collections import Counter
 
-from nltk.tokenize import word_tokenize
-from nltk.compat import Counter
 from nltk.util import ngrams
 
 
-def bleu(references, hypothesis, weights):
+def sentence_bleu(references, hypothesis, weights=[0.25, 0.25, 0.25, 0.25]):
     """
     Calculate BLEU score (Bilingual Evaluation Understudy) from
     Papineni, Kishore, Salim Roukos, Todd Ward, and Wei-Jing Zhu. 2002.
     "BLEU: a method for automatic evaluation of machine translation."
     In Proceedings of ACL. http://www.aclweb.org/anthology/P02-1040.pdf
 
-
-    >>> weights = [0.25, 0.25, 0.25, 0.25]
     >>> hypothesis1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which',
     ... 'ensures', 'that', 'the', 'military', 'always',
     ... 'obeys', 'the', 'commands', 'of', 'the', 'party']
@@ -47,41 +45,148 @@ def bleu(references, hypothesis, weights):
     ... 'army', 'always', 'to', 'heed', 'the', 'directions',
     ... 'of', 'the', 'party']
 
-    >>> bleu([reference1, reference2, reference3], hypothesis1, weights)
+    >>> sentence_bleu([reference1, reference2, reference3], hypothesis1)
     0.5045666840058485
 
-    >>> bleu([reference1, reference2, reference3], hypothesis2, weights)
+    >>> sentence_bleu([reference1, reference2, reference3], hypothesis2)
     0
 
+    The default BLEU calculates a score for up to 4grams using uniform
+    weights. To evaluate your translations with higher/lower order ngrams,
+    use customized weights. E.g. when accounting for up to 6grams with uniform
+    weights:
+
+    >>> weights = [0.1666, 0.1666, 0.1666, 0.1666, 0.1666]
+    >>> sentence_bleu([reference1, reference2, reference3], hypothesis1, weights)
+    0.45838627164939455
+
     :param references: reference sentences
     :type references: list(list(str))
     :param hypothesis: a hypothesis sentence
     :type hypothesis: list(str)
     :param weights: weights for unigrams, bigrams, trigrams and so on
     :type weights: list(float)
+    :return: The sentence-level BLEU score.
+    :rtype: float
     """
-    p_ns = (
-        _modified_precision(references, hypothesis, i)
-        for i, _ in enumerate(weights, start=1)
-    )
+    # Calculates the modified precision *p_n* for each order of ngram.
+    p_ns = []
+    for i, _ in enumerate(weights, start=1):
+        p_n = float(_modified_precision(references, hypothesis, i))
+        p_ns.append(p_n)
 
     try:
+        # Calculates the overall modified precision for all ngrams.
+        # By taking the product of the weights and the respective *p_n*
         s = math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns))
     except ValueError:
         # some p_ns is 0
         return 0
 
-    bp = _brevity_penalty(references, hypothesis)
+    # Calculates the brevity penalty.
+    # *hyp_len* is referred to as *c* in Papineni et. al. (2002)
+    hyp_len = len(hypothesis)
+    # *closest_ref_len* is referred to as *r* variable in Papineni et. al. (2002)
+    closest_ref_len = _closest_ref_length(references, hyp_len)
+    bp = _brevity_penalty(closest_ref_len, hyp_len)
     return bp * math.exp(s)
 
 
+def corpus_bleu(list_of_references, hypotheses, weights=[0.25, 0.25, 0.25, 0.25]):
+    """
+    Calculate a single corpus-level BLEU score (aka. system-level BLEU) for all
+    the hypotheses and their respective references.
+
+    Instead of averaging the sentence level BLEU scores (i.e. marco-average
+    precision), the original BLEU metric (Papineni et al. 2002) accounts for
+    the micro-average precision (i.e. summing the numerators and denominators
+    for each hypothesis-reference(s) pairs before the division).
+
+    >>> hyp1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which',
+    ... 'ensures', 'that', 'the', 'military', 'always',
+    ... 'obeys', 'the', 'commands', 'of', 'the', 'party']
+    >>> ref1a = ['It', 'is', 'a', 'guide', 'to', 'action', 'that',
+    ... 'ensures', 'that', 'the', 'military', 'will', 'forever',
+    ... 'heed', 'Party', 'commands']
+    >>> ref1b = ['It', 'is', 'the', 'guiding', 'principle', 'which',
+    ... 'guarantees', 'the', 'military', 'forces', 'always',
+    ... 'being', 'under', 'the', 'command', 'of', 'the', 'Party']
+    >>> ref1c = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the',
+    ... 'army', 'always', 'to', 'heed', 'the', 'directions',
+    ... 'of', 'the', 'party']
+
+    >>> hyp2 = ['he', 'read', 'the', 'book', 'because', 'he', 'was',
+    ... 'interested', 'in', 'world', 'history']
+    >>> ref2a = ['he', 'was', 'interested', 'in', 'world', 'history',
+    ... 'because', 'he', 'read', 'the', 'book']
+
+    >>> list_of_references = [[ref1a, ref1b, ref1c], [ref2a]]
+    >>> hypotheses = [hyp1, hyp2]
+    >>> corpus_bleu(list_of_references, hypotheses)
+    0.5520516129306314
+
+    The example below show that corpus_bleu() is different from averaging
+    sentence_bleu() for hypotheses
+
+    >>> score1 = sentence_bleu([ref1a, ref1b, ref1c], hyp1)
+    >>> score2 = sentence_bleu([ref2a], hyp2)
+    >>> (score1 + score2) / 2
+    0.6223247442490669
+
+    :param references: a corpus of lists of reference sentences, w.r.t. hypotheses
+    :type references: list(list(list(str)))
+    :param hypotheses: a list of hypothesis sentences
+    :type hypotheses: list(list(str))
+    :param weights: weights for unigrams, bigrams, trigrams and so on
+    :type weights: list(float)
+    :return: The corpus-level BLEU score.
+    :rtype: float
+    """
+    p_numerators = Counter() # Key = ngram order, and value = no. of ngram matches.
+    p_denominators = Counter() # Key = ngram order, and value = no. of ngram in ref.
+    hyp_lengths, ref_lengths = 0, 0
+
+    assert len(list_of_references) == len(hypotheses), "The number of hypotheses and their reference(s) should be the same"
+
+    # Iterate through each hypothesis and their corresponding references.
+    for references, hypothesis in zip(list_of_references, hypotheses):
+        # For each order of ngram, calculate the numerator and
+        # denominator for the corpus-level modified precision.
+        for i, _ in enumerate(weights, start=1):
+            p_i = _modified_precision(references, hypothesis, i)
+            p_numerators[i] += p_i.numerator
+            p_denominators[i] += p_i.denominator
+
+        # Calculate the hypothesis length and the closest reference length.
+        # Adds them to the corpus-level hypothesis and reference counts.
+        hyp_len = len(hypothesis)
+        hyp_lengths += hyp_len
+        ref_lengths += _closest_ref_length(references, hyp_len)
+
+    # Calculate corpus-level brevity penalty.
+    bp = _brevity_penalty(ref_lengths, hyp_lengths)
+
+    # Calculate corpus-level modified precision.
+    p_n = []
+    for i, w in enumerate(weights, start=1):
+        pn = p_numerators[i] / p_denominators[i]
+        p_n.append(w* math.log(pn))
+
+    return bp * math.exp(math.fsum(p_n))
+
+
 def _modified_precision(references, hypothesis, n):
     """
     Calculate modified ngram precision.
 
     The normal precision method may lead to some wrong translations with
     high-precision, e.g., the translation, in which a word of reference
-    repeats several times, has very high precision.
+    repeats several times, has very high precision.
+
+    This function only returns the Fraction object that contains the numerator
+    and denominator necessary to calculate the corpus-level precision.
+    To calculate the modified precision for a single pair of hypothesis and
+    references, cast the Fraction object into a float.
 
     The famous "the the the ... " example shows that you can get BLEU precision
     by duplicating high frequency words.
@@ -90,7 +195,7 @@ def _modified_precision(references, hypothesis, n):
     >>> reference2 = 'there is a cat on the mat'.split()
     >>> hypothesis1 = 'the the the the the the the'.split()
     >>> references = [reference1, reference2]
-    >>> _modified_precision(references, hypothesis1, n=1)
+    >>> float(_modified_precision(references, hypothesis1, n=1))
    0.2857142857142857
 
     In the modified n-gram precision, a reference word will be considered
@@ -108,9 +213,9 @@ def _modified_precision(references, hypothesis, n):
     ... 'of', 'the', 'party']
     >>> hypothesis = 'of the'.split()
     >>> references = [reference1, reference2, reference3]
-    >>> _modified_precision(references, hypothesis, n=1)
+    >>> float(_modified_precision(references, hypothesis, n=1))
     1.0
-    >>> _modified_precision(references, hypothesis, n=2)
+    >>> float(_modified_precision(references, hypothesis, n=2))
     1.0
 
     An example of a normal machine translation hypothesis:
@@ -136,39 +241,64 @@ def _modified_precision(references, hypothesis, n):
     ... 'army', 'always', 'to', 'heed', 'the', 'directions',
     ... 'of', 'the', 'party']
     >>> references = [reference1, reference2, reference3]
-    >>> _modified_precision(references, hypothesis1, n=1)
+    >>> float(_modified_precision(references, hypothesis1, n=1))
     0.9444444444444444
-    >>> _modified_precision(references, hypothesis2, n=1)
+    >>> float(_modified_precision(references, hypothesis2, n=1))
     0.5714285714285714
-    >>> _modified_precision(references, hypothesis1, n=2)
+    >>> float(_modified_precision(references, hypothesis1, n=2))
     0.5882352941176471
-    >>> _modified_precision(references, hypothesis2, n=2)
+    >>> float(_modified_precision(references, hypothesis2, n=2))
     0.07692307692307693
-
+
+
     :param references: A list of reference translations.
     :type references: list(list(str))
     :param hypothesis: A hypothesis translation.
     :type hypothesis: list(str)
     :param n: The ngram order.
     :type n: int
+    :return: BLEU's modified precision for the nth order ngram.
+    :rtype: Fraction
     """
     counts = Counter(ngrams(hypothesis, n))
 
     if not counts:
-        return 0
+        return Fraction(0)
 
     max_counts = {}
     for reference in references:
         reference_counts = Counter(ngrams(reference, n))
         for ngram in counts:
             max_counts[ngram] = max(max_counts.get(ngram, 0), reference_counts[ngram])
 
-    clipped_counts = dict((ngram, min(count, max_counts[ngram])) for ngram, count in counts.items())
-
-    return sum(clipped_counts.values()) / sum(counts.values())
+    clipped_counts = dict((ngram, min(count, max_counts[ngram]))
+                          for ngram, count in counts.items())
+
+    numerator = sum(clipped_counts.values())
+    denominator = sum(counts.values())
+
+    return Fraction(numerator, denominator)
+
 
+def _closest_ref_length(references, hyp_len):
+    """
+    This function finds the reference that is the closest length to the
+    hypothesis. The closest reference length is referred to as *r* variable
+    from the brevity penalty formula in Papineni et. al. (2002)
+
+    :param references: A list of reference translations.
+    :type references: list(list(str))
+    :param hypothesis: The length of the hypothesis.
+    :type hypothesis: int
+    :return: The length of the reference that's closest to the hypothesis.
+    :rtype: int
+    """
+    ref_lens = (len(reference) for reference in references)
+    closest_ref_len = min(ref_lens, key=lambda ref_len:
+                          (abs(ref_len - hyp_len), ref_len))
+    return closest_ref_len
 
-def _brevity_penalty(references, hypothesis):
+def _brevity_penalty(closest_ref_len, hyp_len):
     """
     Calculate brevity penalty.
 
@@ -184,15 +314,19 @@ def _brevity_penalty(references, hypothesis):
     >>> reference3 = list('aaaaaaaaaaaaaaaaa') # i.e. ['a'] * 17
     >>> hypothesis = list('aaaaaaaaaaaa') # i.e. ['a'] * 12
     >>> references = [reference1, reference2, reference3]
-    >>> _brevity_penalty(references, hypothesis)
+    >>> hyp_len = len(hypothesis)
+    >>> closest_ref_len = _closest_ref_length(references, hyp_len)
+    >>> _brevity_penalty(closest_ref_len, hyp_len)
     1.0
 
     In case a hypothesis translation is shorter than the references, penalty is
     applied.
 
     >>> references = [['a'] * 28, ['a'] * 28]
     >>> hypothesis = ['a'] * 12
-    >>> _brevity_penalty(references, hypothesis)
+    >>> hyp_len = len(hypothesis)
+    >>> closest_ref_len = _closest_ref_length(references, hyp_len)
+    >>> _brevity_penalty(closest_ref_len, hyp_len)
     0.2635971381157267
 
     The length of the closest reference is used to compute the penalty. If the
@@ -202,7 +336,9 @@ def _brevity_penalty(references, hypothesis):
 
     >>> references = [['a'] * 13, ['a'] * 2]
     >>> hypothesis = ['a'] * 12
-    >>> _brevity_penalty(references, hypothesis)
+    >>> hyp_len = len(hypothesis)
+    >>> closest_ref_len = _closest_ref_length(references, hyp_len)
+    >>> _brevity_penalty(closest_ref_len, hyp_len)
     0.9200444146293233
 
     The brevity penalty doesn't depend on reference order. More importantly,
@@ -211,34 +347,42 @@ def _brevity_penalty(references, hypothesis):
 
     >>> references = [['a'] * 13, ['a'] * 11]
     >>> hypothesis = ['a'] * 12
-    >>> bp1 = _brevity_penalty(references, hypothesis)
-    >>> bp2 = _brevity_penalty(reversed(references),hypothesis)
+    >>> hyp_len = len(hypothesis)
+    >>> closest_ref_len = _closest_ref_length(references, hyp_len)
+    >>> bp1 = _brevity_penalty(closest_ref_len, hyp_len)
+    >>> hyp_len = len(hypothesis)
+    >>> closest_ref_len = _closest_ref_length(reversed(references), hyp_len)
+    >>> bp2 = _brevity_penalty(closest_ref_len, hyp_len)
     >>> bp1 == bp2 == 1
     True
 
     A test example from mteval-v13a.pl (starting from the line 705):
 
     >>> references = [['a'] * 11, ['a'] * 8]
     >>> hypothesis = ['a'] * 7
-    >>> _brevity_penalty(references, hypothesis)
+    >>> hyp_len = len(hypothesis)
+    >>> closest_ref_len = _closest_ref_length(references, hyp_len)
+    >>> _brevity_penalty(closest_ref_len, hyp_len)
     0.8668778997501817
 
     >>> references = [['a'] * 11, ['a'] * 8, ['a'] * 6, ['a'] * 7]
     >>> hypothesis = ['a'] * 7
-    >>> _brevity_penalty(references, hypothesis)
+    >>> hyp_len = len(hypothesis)
+    >>> closest_ref_len = _closest_ref_length(references, hyp_len)
+    >>> _brevity_penalty(closest_ref_len, hyp_len)
     1.0
 
-    :param references: A list of reference translations.
-    :type references: list(list(str))
-    :param hypothesis: A hypothesis translation.
-    :type hypothesis: list(str)
+    :param hyp_len: The length of the hypothesis for a single sentence OR the
+    sum of all the hypotheses' lengths for a corpus
+    :type hyp_len: int
+    :param closest_ref_len: The length of the closest reference for a single
+    hypothesis OR the sum of all the closest references for every hypotheses.
+    :type closest_reference_len: int
+    :return: BLEU's brevity penalty.
+    :rtype: float
     """
-    c = len(hypothesis)
-    ref_lens = (len(reference) for reference in references)
-    r = min(ref_lens, key=lambda ref_len: (abs(ref_len - c), ref_len))
-
-    if c > r:
+    if hyp_len > closest_ref_len:
         return 1
     else:
-        return math.exp(1 - r / c)
+        return math.exp(1 - closest_ref_len / hyp_len)
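
The corpus_bleu() doctest above captures the main design point of this PR: the corpus score pools the clipped n-gram counts across all hypothesis-reference pairs (micro-average) instead of averaging per-sentence scores (macro-average). A standalone sketch of that comparison, assuming an NLTK build that includes this commit; the expected values are the ones quoted in the new doctest:

# Contrast corpus-level BLEU (micro-average) with the mean of
# sentence-level BLEU scores (macro-average), using the doctest data.
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu

hyp1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which', 'ensures', 'that',
        'the', 'military', 'always', 'obeys', 'the', 'commands', 'of', 'the', 'party']
ref1a = ['It', 'is', 'a', 'guide', 'to', 'action', 'that', 'ensures', 'that',
         'the', 'military', 'will', 'forever', 'heed', 'Party', 'commands']
ref1b = ['It', 'is', 'the', 'guiding', 'principle', 'which', 'guarantees',
         'the', 'military', 'forces', 'always', 'being', 'under', 'the',
         'command', 'of', 'the', 'Party']
ref1c = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the', 'army',
         'always', 'to', 'heed', 'the', 'directions', 'of', 'the', 'party']

hyp2 = ['he', 'read', 'the', 'book', 'because', 'he', 'was',
        'interested', 'in', 'world', 'history']
ref2a = ['he', 'was', 'interested', 'in', 'world', 'history',
         'because', 'he', 'read', 'the', 'book']

# Micro-average: numerators and denominators are summed over the corpus
# before dividing, then combined with a corpus-level brevity penalty.
print(corpus_bleu([[ref1a, ref1b, ref1c], [ref2a]], [hyp1, hyp2]))  # 0.5520... per the doctest

# Macro-average: score each sentence separately and take the mean.
score1 = sentence_bleu([ref1a, ref1b, ref1c], hyp1)
score2 = sentence_bleu([ref2a], hyp2)
print((score1 + score2) / 2)  # 0.6223... per the doctest

The gap between the two numbers is the reason the PR adds corpus_bleu() rather than recommending an average over sentence_bleu() calls.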

0 commit comments
