8000 [External] [stdlib] Introduce `random.shuffle` for `List` (#51416) · modular/modular@eb27573 · GitHub
[go: up one dir, main page]

Skip to content

Commit eb27573

Browse files
jjvrawmodularbot
authored andcommitted
[External] [stdlib] Introduce random.shuffle for List (#51416)
[External] [stdlib] Introduce `random.shuffle` for `List` Introduce `random.shuffle` for `List`. Implementation follows the Fisher-Yates shuffle. Co-authored-by: Joshua James Venter <venter.joshua@gmail.com> Closes #3327 MODULAR_ORIG_COMMIT_REV_ID: 911eadee8d34911f653e8c69f85fcd89f5cae344
1 parent 6238d95 commit eb27573

File tree

4 files changed

+126
-3
lines changed

4 files changed

+126
-3
lines changed

docs/changelog.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,18 @@ what we publish.
110110
- The `rebind` standard library function now works with memory-only types in
111111
addition to `@register_passable("trivial")` ones, without requiring a copy.
112112

113+
- Introduce `random.shuffle` for `List`.
114+
([PR #3327](https://github.com/modularml/mojo/pull/3327) by [@jjvraw](https://github.com/jjvraw))
115+
116+
Example:
117+
118+
```mojo
119+
from random import shuffle
120+
121+
var l = List[Int](1, 2, 3, 4, 5)
122+
shuffle(l)
123+
```
124+
113125
- The `Dict.__getitem__` method now returns a reference instead of a copy of
114126
the value (or raises). This improves the performance of common code that
115127
uses `Dict` by allowing borrows from the `Dict` elements.

stdlib/src/random/__init__.mojo

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,5 @@ from .random import (
2121
random_si64,
2222
random_ui64,
2323
seed,
24+
shuffle,
2425
)

stdlib/src/random/random.mojo

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ from random import seed
2222
from sys import bitwidthof, external_call
2323
from sys.ffi import OpaquePointer
2424
from time import perf_counter_ns
25-
from collections import Optional
26-
25+
from collections import Optional, List
2726
from memory import UnsafePointer
2827
from math import floor
2928
import math
@@ -222,3 +221,19 @@ fn randn[
222221
for i in range(size):
223222
ptr[i] = randn_float64(mean, variance).cast[type]()
224223
return
224+
225+
226+
fn shuffle[T: CollectionElement, //](inout list: List[T]):
227+
"""Shuffles the elements of the list randomly.
228+
229+
Performs an in-place Fisher-Yates shuffle on the provided list.
230+
231+
Args:
232+
list: The list to modify.
233+
234+
Parameters:
235+
T: The type of element in the List.
236+
"""
237+
for i in reversed(range(len(list))):
238+
var j = int(random_ui64(0, i))
239+
list.swap_elements(i, j)

stdlib/test/random/test_random.mojo

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,14 @@
1212
# ===----------------------------------------------------------------------=== #
1313
# RUN: %mojo %s
1414

15-
from random import randn_float64, random_float64, random_si64, random_ui64, seed
15+
from random import (
16+
randn_float64,
17+
random_float64,
18+
random_si64,
19+
random_ui64,
20+
seed,
21+
shuffle,
22+
)
1623

1724
from testing import assert_equal, assert_true
1825

@@ -73,6 +80,94 @@ def test_seed():
7380
assert_equal(some_unsigned_integer, random_ui64(0, 255))
7481

7582

83+
def test_shuffle():
84+
# TODO: Clean up with list comprehension when possible.
85+
86+
# Property tests
87+
alias L_i = List[Int]
88+
alias L_s = List[String]
89+
var a = L_i(1, 2, 3, 4)
90+
var b = L_i(1, 2, 3, 4)
91+
var c = L_s("Random", "shuffle", "in", "Mojo")
92+
var d = L_s("Random", "shuffle", "in", "Mojo")
93+
94+
shuffle(b)
95+
assert_equal(len(a), len(b))
96+
assert_true(a != b)
97+
for i in range(len(b)):
98+
assert_true(b[i] in a)
99+
100+
shuffle(d)
101+
assert_equal(len(c), len(d))
102+
assert_true(c != d)
103+
for i in range(len(d)):
104+
assert_true(d[i] in c)
105+
106+
var e = L_i(21)
107+
shuffle(e)
108+
assert_true(e == L_i(21))
109+
var f = L_s("Mojo")
110+
shuffle(f)
111+
assert_true(f == L_s("Mojo"))
112+
113+
alias L_l = List[List[Int]]
114+
var g = L_l()
115+
var h = L_l()
116+
for i in range(10):
117+
g.append(L_i(i, i + 1, i + 3))
118+
h.append(L_i(i, i + 1, i + 3))
119+
shuffle(g)
120+
# TODO: Uncomment when possible
121+
# assert_true(g != h)
122+
assert_equal(len(g), len(h))
123+
for i in range(10):
124+
# Currently, the below does not compile.
125+
# assert_true(g.__contains__(L_i(i, i + 1, i + 3)))
126+
var target: List[Int] = L_i(i, i + 1, i + 3)
127+
var found = False
128+
for j in range(len(g)):
129+
if g[j] == target:
130+
found = True
131+
break
132+
assert_true(found)
133+
134+
alias L_l_s = List[List[String]]
135+
var i = L_l_s()
136+
var j = L_l_s()
137+
for x in range(10):
138+
i.append(L_s(str(x), str(x + 1), str(x + 3)))
139+
j.append(L_s(str(x), str(x + 1), str(x + 3)))
140+
shuffle(i)
141+
# TODO: Uncomment when possible
142+
# assert_true(g != h)
143+
assert_equal(len(i), len(j))
144+
for x in range(10):
145+
var target: List[String] = L_s(str(x), str(x + 1), str(x + 3))
146+
var found = False
147+
for y in range(len(i)):
148+
if j[y] == target:
149+
found = True
150+
break
151+
assert_true(found)
152+
153+
# Given the number of permutations of size 1000 is 1000!,
154+
# we rely on the assertion that a truly random shuffle should not
155+
# result in the same order as the to pre-shuffle list with extremely
156+
# high probability.
157+
var l = L_i()
158+
var m = L_i()
159+
for i in range(1000):
160+
l.append(i)
161+
m.append(i)
162+
shuffle(l)
163+
assert_equal(len(l), len(m))
164+
assert_true(l != m)
165+
shuffle(m)
166+
assert_equal(len(l), len(m))
167+
assert_true(l != m)
168+
169+
76170
def main():
77171
test_random()
78172
test_seed()
173+
test_shuffle()

0 commit comments

Comments
 (0)
0