bpo-40541: Add optional *counts* parameter to random.sample() (GH-19970) · python/cpython@81a5fc3 · GitHub
[go: up one dir, main page]

Skip to content

Commit 81a5fc3

Browse files
authored
bpo-40541: Add optional *counts* parameter to random.sample() (GH-19970)
1 parent 2effef7 commit 81a5fc3

File tree

4 files changed

+116
-13
lines changed

4 files changed

+116
-13
lines changed

Doc/library/random.rst

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ Functions for sequences
217217
The optional parameter *random*.
218218

219219

220-
.. function:: sample(population, k)
220+
.. function:: sample(population, k, *, counts=None)
221221

222222
Return a *k* length list of unique elements chosen from the population sequence
223223
or set. Used for random sampling without replacement.
@@ -231,13 +231,21 @@ Functions for sequences
231231
Members of the population need not be :term:`hashable` or unique. If the population
232232
contains repeats, then each occurrence is a possible selection in the sample.
233233

234+
Repeated elements can be specified one at a time or with the optional
235+
keyword-only *counts* parameter. For example, ``sample(['red', 'blue'],
236+
counts=[4, 2], k=5)`` is equivalent to ``sample(['red', 'red', 'red', 'red',
237+
'blue', 'blue'], k=5)``.
238+
234239
To choose a sample from a range of integers, use a :func:`range` object as an
235240
argument. This is especially fast and space efficient for sampling from a large
236241
population: ``sample(range(10000000), k=60)``.
237242

238243
If the sample size is larger than the population size, a :exc:`ValueError`
239244
is raised.
240245

246+
.. versionchanged:: 3.9
247+
Added the *counts* parameter.
248+
241249
.. deprecated:: 3.9
242250
In the future, the *population* must be a sequence. Instances of
243251
:class:`set` are no longer supported. The set must first be converted
@@ -420,12 +428,11 @@ Simulations::
420428
>>> choices(['red', 'black', 'green'], [18, 18, 2], k=6)
421429
['red', 'green', 'black', 'black', 'red', 'black']
422430

423-
>>> # Deal 20 cards without replacement from a deck of 52 playing cards
424-
>>> # and determine the proportion of cards with a ten-value
425-
>>> # (a ten, jack, queen, or king).
426-
>>> deck = collections.Counter(tens=16, low_cards=36)
427-
>>> seen = sample(list(deck.elements()), k=20)
428-
>>> seen.count('tens') / 20
431+
>>> # Deal 20 cards without replacement from a deck
432+
>>> # of 52 playing cards, and determine the proportion of cards
433+
>>> # with a ten-value: ten, jack, queen, or king.
434+
>>> dealt = sample(['tens', 'low cards'], counts=[16, 36], k=20)
435+
>>> dealt.count('tens') / 20
429436
0.15
430437

431438
>>> # Estimate the probability of getting 5 or more heads from 7 spins

Lib/random.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def shuffle(self, x, random=None):
331331
j = _int(random() * (i+1))
332332
x[i], x[j] = x[j], x[i]
333333

334-
def sample(self, population, k):
334+
def sample(self, population, k, *, counts=None):
335335
"""Chooses k unique random elements from a population sequence or set.
336336
337337
Returns a new list containing elements from the population while
@@ -344,9 +344,21 @@ def sample(self, population, k):
344344
population contains repeats, then each occurrence is a possible
345345
selection in the sample.
346346
347-
To choose a sample in a range of integers, use range as an argument.
348-
This is especially fast and space efficient for sampling from a
349-
large population: sample(range(10000000), 60)
347+
Repeated elements can be specified one at a time or with the optional
348+
counts parameter. For example:
349+
350+
sample(['red', 'blue'], counts=[4, 2], k=5)
351+
352+
is equivalent to:
353+
354+
sample(['red', 'red', 'red', 'red', 'blue', 'blue'], k=5)
355+
356+
To choose a sample from a range of integers, use range() for the
357+
population argument. This is especially fast and space efficient
358+
for sampling from a large population:
359+
360+
sample(range(10000000), 60)
361+
350362
"""
351363

352364
# Sampling without replacement entails tracking either potential
@@ -379,8 +391,20 @@ def sample(self, population, k):
379391
population = tuple(population)
380392
if not isinstance(population, _Sequence):
381393
raise TypeError("Population must be a sequence. For dicts or sets, use sorted(d).")
382-
randbelow = self._randbelow
383394
n = len(population)
395+
if counts is not None:
396+
cum_counts = list(_accumulate(counts))
397+
if len(cum_counts) != n:
398+
raise ValueError('The number of counts does not match the population')
399+
total = cum_counts.pop()
400+
if not isinstance(total, int):
401+
raise TypeError('Counts must be integers')
402+
if total <= 0:
403+
raise ValueError('Total of counts must be greater than zero')
404+
selections = sample(range(total), k=k)
405+
bisect = _bisect
406+
return [population[bisect(cum_counts, s)] for s in selections]
407+
randbelow = self._randbelow
384408
if not 0 <= k <= n:
385409
raise ValueError("Sample larger than population or is negative")
386410
result = [None] * k

Lib/test/test_random.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from math import log, exp, pi, fsum, sin, factorial
1010
from test import support
1111
from fractions import Fraction
12-
12+
from collections import Counter
1313

1414
class TestBasicOps:
1515
# Superclass with tests common to all generators.
@@ -161,6 +161,77 @@ def test_sample_on_sets(self):
161161
population = {10, 20, 30, 40, 50, 60, 70}
162162
self.gen.sample(population, k=5)
163163

164+
def test_sample_with_counts(self):
165+
sample = self.gen.sample
166+
167+
# General case
168+
colors = ['red', 'green', 'blue', 'orange', 'black', 'brown', 'amber']
169+
counts = [500, 200, 20, 10, 5, 0, 1 ]
170+
k = 700
171+
summary = Counter(sample(colors, counts=counts, k=k))
172+
self.assertEqual(sum(summary.values()), k)
173+
for color, weight in zip(colors, counts):
174+
self.assertLessEqual(summary[color], weight)
175+
self.assertNotIn('brown', summary)
176+
177+
# Case that exhausts the population
178+
k = sum(counts)
179+
summary = Counter(sample(colors, counts=counts, k=k))
180+
self.assertEqual(sum(summary.values()), k)
181+
for color, weight in zip(colors, counts):
182+
self.assertLessEqual(summary[color], weight)
183+
self.assertNotIn('brown', summary)
184+
185+
# Case with population size of 1
186+
summary = Counter(sample(['x'], counts=[10], k=8))
187+
self.assertEqual(summary, Counter(x=8))
188+
189+
# Case with all counts equal.
190+
nc = len(colors)
191+
summary = Counter(sample(colors, counts=[10]*nc, k=10*nc))
192+
self.assertEqual(summary, Counter(10*colors))
193+
194+
# Test error handling
195+
with self.assertRaises(TypeError):
196+
sample(['red', 'green', 'blue'], counts=10, k=10) # counts not iterable
197+
with self.assertRaises(ValueError):
198+
sample(['red', 'green', 'blue'], counts=[-3, -7, -8], k=2) # counts are negative
199+
with self.assertRaises(ValueError):
200+
sample(['red', 'green', 'blue'], counts=[0, 0, 0], k=2) # counts are zero
201+
with self.assertRaises(ValueError):
202+
sample(['red', 'green'], counts=[10, 10], k=21) # population too small
203+
with self.assertRaises(ValueError):
204+
sample(['red', 'green', 'blue'], counts=[1, 2], k=2) # too few counts
205+
with self.assertRaises(ValueError):
206+
sample(['red', 'green', 'blue'], counts=[1, 2, 3, 4], k=2) # too many counts
207+
208+
def test_sample_counts_equivalence(self):
209+
# Test the documented strong equivalence to a sample with repeated elements.
210+
# We run this test on random.Random() which makes deterministic selections
211+
# for a given seed value.
212+
sample = random.sample
213+
seed = random.seed
214+
215+
colors = ['red', 'green', 'blue', 'orange', 'black', 'amber']
216+
counts = [500, 200, 20, 10, 5, 1 ]
217+
k = 700
218+
seed(8675309)
219+
s1 = sample(colors, counts=counts, k=k)
220+
seed(8675309)
221+
expanded = [color for (color, count) in zip(colors, counts) for i in range(count)]
222+
self.assertEqual(len(expanded), sum(counts))
223+
s2 = sample(expanded, k=k)
224+
self.assertEqual(s1, s2)
225+
226+
pop = 'abcdefghi'
227+
counts = [10, 9, 8, 7, 6, 5, 4, 3, 2]
228+
seed(8675309)
229+
s1 = ''.join(sample(pop, counts=counts, k=30))
230+
expanded = ''.join([letter for (letter, count) in zip(pop, counts) for i in range(count)])
231+
seed(8675309)
232+
s2 = ''.join(sample(expanded, k=30))
233+
self.assertEqual(s1, s2)
234+
164235
def test_choices(self):
165236
choices = self.gen.choices
166237
data = ['red', 'green', 'blue', 'yellow']
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added an optional *counts* parameter to random.sample().

0 commit comments

Comments
 (0)
0