8000 bpo-42944 Fix Random.sample when counts is not None (GH-24235) · python/cpython@f7b5bac · GitHub
[go: up one dir, main page]

Skip to content

Commit f7b5bac

Browse files
authored
bpo-42944 Fix Random.sample when counts is not None (GH-24235)
1 parent 314b878 commit f7b5bac

File tree

3 files changed

+29
-28
lines changed

3 files changed

+29
-28
lines changed

Lib/random.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ def sample(self, population, k, *, counts=None):
479479
raise TypeError('Counts must be integers')
480480
if total <= 0:
481481
raise ValueError('Total of counts must be greater than zero')
482-
selections = sample(range(total), k=k)
482+
selections = self.sample(range(total), k=k)
483483
bisect = _bisect
484484
return [population[bisect(cum_counts, s)] for s in selections]
485485
randbelow = self._randbelow

Lib/test/test_random.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -223,33 +223,6 @@ def test_sample_with_counts(self):
223223
with self.assertRaises(ValueError):
224224
sample(['red', 'green', 'blue'], counts=[1, 2, 3, 4], k=2) # too many counts
225225

226-
def test_sample_counts_equivalence(self):
227-
# Test the documented strong equivalence to a sample with repeated elements.
228-
# We run this test on random.Random() which makes deterministic selections
229-
# for a given seed value.
230-
sample = random.sample
231-
seed = random.seed
232-
233-
colors = ['red', 'green', 'blue', 'orange', 'black', 'amber']
234-
counts = [500, 200, 20, 10, 5, 1 ]
235-
k = 700
236-
seed(8675309)
237-
s1 = sample(colors, counts=counts, k=k)
238-
seed(8675309)
239-
expanded = [color for (color, count) in zip(colors, counts) for i in range(count)]
240-
self.assertEqual(len(expanded), sum(counts))
241-
s2 = sample(expanded, k=k)
242-
self.assertEqual(s1, s2)
243-
244-
pop = 'abcdefghi'
245-
counts = [10, 9, 8, 7, 6, 5, 4, 3, 2]
246-
seed(8675309)
247-
s1 = ''.join(sample(pop, counts=counts, k=30))
248-
expanded = ''.join([letter for (letter, count) in zip(pop, counts) for i in range(count)])
249-
seed(8675309)
250-
s2 = ''.join(sample(expanded, k=30))
251-
self.assertEqual(s1, s2)
252-
253226
def test_choices(self):
254227
choices = self.gen.choices
255228
data = ['red', 'green', 'blue', 'yellow']
@@ -957,6 +930,33 @@ def test_randbytes_getrandbits(self):
957930
self.assertEqual(self.gen.randbytes(n),
958931
gen2.getrandbits(n * 8).to_bytes(n, 'little'))
959932

933+
def test_sample_counts_equivalence(self):
934+
# Test the documented strong equivalence to a sample with repeated elements.
935+
# We run this test on random.Random() which makes deterministic selections
936+
# for a given seed value.
937+
sample = self.gen.sample
938+
seed = self.gen.seed
939+
940+
colors = ['red', 'green', 'blue', 'orange', 'black', 'amber']
941+
counts = [500, 200, 20, 10, 5, 1 ]
942+
k = 700
943+
seed(8675309)
944+
s1 = sample(colors, counts=counts, k=k)
945+
seed(8675309)
946+
expanded = [color for (color, count) in zip(colors, counts) for i in range(count)]
947+
self.assertEqual(len(expanded), sum(counts))
948+
s2 = sample(expanded, k=k)
949+
self.assertEqual(s1, s2)
950+
951+
pop = 'abcdefghi'
952+
counts = [10, 9, 8, 7, 6, 5, 4, 3, 2]
953+
seed(8675309)
954+
s1 = ''.join(sample(pop, counts=counts, k=30))
955+
expanded = ''.join([letter for (letter, count) in zip(pop, counts) for i in range(count)])
956+
seed(8675309)
957+
s2 = ''.join(sample(expanded, k=30))
958+
self.assertEqual(s1, s2)
959+
960960

961961
def gamma(z, sqrt2pi=(2.0*pi)**0.5):
962962
# Reflection to right half of complex plane
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix ``random.Random.sample`` when ``counts`` argument is not ``None``.

0 commit comments

Comments
 (0)
0