8000 introduce `random.shuffle` for `List` · modular/modular@7789dee · GitHub
[go: up one dir, main page]

Skip to content

Commit 7789dee

Browse files
committed
introduce random.shuffle for List
Signed-off-by: Joshua James Venter <venter.joshua@gmail.com>
1 parent c17e096 commit 7789dee

File tree

4 files changed

+138
-2
lines changed

4 files changed

+138
-2
lines changed

docs/changelog.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,18 @@ future and `StringSlice.__len__` now does return the Unicode codepoints length.
350350
# "Mojo 'Mojo'"
351351
```
352352

353+
- Introduce `random.shuffle` for `List`.
354+
([PR #3327](https://github.com/modularml/mojo/pull/3327) by [@jjvraw](https://github.com/jjvraw))
355+
356+
Example:
357+
358+
```mojo
359+
from random import shuffle
360+
361+
var l = List[Int](1, 2, 3, 4, 5)
362+
shuffle(l)
363+
```
364+
353365
### 🦋 Changed
354366

355367
- The pointer aliasing semantics of Mojo have changed. Initially, Mojo adopted a

stdlib/src/random/__init__.mojo

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,5 @@ from .random import (
2121
random_si64,
2222
random_ui64,
2323
seed,
24+
shuffle,
2425
)

stdlib/src/random/random.mojo

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ from random import seed
2121

2222
from sys import bitwidthof, external_call
2323
from time import perf_counter_ns
24-
2524
from memory import UnsafePointer
25+
from collections import List
2626

2727

2828
fn _get_random_state() -> UnsafePointer[NoneType]:
@@ -201,3 +201,28 @@ fn randn[
201201
for i in range(size):
202202
ptr[i] = randn_float64(mean, variance).cast[type]()
203203
return
204+
205+
206+
fn shuffle[T: CollectionElement](inout list_in: List[T]):
207+
"""Shuffles the elements of the list randomly.
208+
209+
Performs an in-place Fisher-Yates shuffle on the provided list.
210+
211+
Args:
212+
list_in: The list to modify.
213+
214+
Parameters:
215+
T: The type of element in the List.
216+
"""
217+
218+
var length = len(list_in)
219+
220+
for i in range(length - 1, 0, -1):
221+
var j: Int = random_ui64(0, i).cast[DType.int64]().value
222+
223+
var i_ptr = list_in.data.offset(i)
224+
var j_ptr = list_in.data.offset(j)
225+
226+
var tmp = i_ptr.take_pointee()
227+
j_ptr.move_pointee_into(i_ptr)
228+
j_ptr.init_pointee_move(tmp^)

stdlib/test/random/test_random.mojo

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,14 @@
1212
# ===----------------------------------------------------------------------=== #
1313
# RUN: %mojo %s
1414

15-
from random import randn_float64, random_float64, random_si64, random_ui64, seed
15+
from random import (
16+
randn_float64,
17+
random_float64,
18+
random_si64,
19+
random_ui64,
20+
seed,
21+
shuffle,
22+
)
1623

1724
from testing import assert_equal, assert_true
1825

@@ -73,6 +80,97 @@ def test_seed():
7380
assert_equal(some_unsigned_integer, random_ui64(0, 255))
7481

7582

83+
def test_shuffle():
84+
# TODO: Clean up with list comprehension when possible.
85+
86+
# Deterministic Tests
87+
alias L_i = List[Int]
88+
alias L_s = List[String]
89+
var a = L_i(1, 2, 3, 4)
90+
var b = L_i(1, 2, 3, 4)
91+
var c = L_s("Random", "shuffle", "in", "Mojo")
92+
var d = L_s("Random", "shuffle", "in", "Mojo")
93+
94+
shuffle(b)
95+
assert_equal(len(a), len(b))
96+
assert_true(a != b)
97+
for i in range(len(b)):
98+
assert_true(a.__contains__(b[i]))
99+
100+
shuffle(d)
101+
assert_equal(len(c), len(d))
102+
assert_true(c != d)
103+
for i in range(len(d)):
104+
assert_true(c.__contains__(d[i]))
105+
106+
var e = L_i(21)
107+
shuffle(e)
108+
assert_true(e == L_i(21))
109+
var f = L_s("Mojo")
110+
shuffle(f)
111+
assert_true(f == L_s("Mojo"))
112+
113+
alias L_l = List[List[Int]]
114+
var g = L_l()
115+
var h = L_l()
116+
for i in range(10):
117+
g.append(L_i(i, i + 1, i + 3))
118+
h.append(L_i(i, i + 1, i + 3))
119+
shuffle(g)
120+
# TODO: Uncomment when possible
121+
# assert_true(g != h)
122+
assert_equal(len(g), len(h))
123+
for i in range(10):
124+
# Currently, the below does not compile.
125+
# Error: invalid call to '__contains__':
126+
# could not deduce positional-only parameter #1 of callee '__contains__'
127+
# assert_true(g.__contains__(L_i(i, i + 1, i + 3)))
128+
var target: List[Int] = L_i(i, i + 1, i + 3)
129+
var found = False
130+
for j in range(len(g)):
131+
if g[j] == target:
132+
found = True
133+
break
134+
assert_true(found)
135+
136+
alias L_l_s = List[List[String]]
137+
var i = L_l_s()
138+
var j = L_l_s()
139+
for x in range(10):
140+
i.append(L_s(str(x), str(x + 1), str(x + 3)))
141+
j.append(L_s(str(x), str(x + 1), str(x + 3)))
142+
shuffle(i)
143+
# TODO: Uncomment when possible
144+
# assert_true(g != h)
145+
assert_equal(len(i), len(j))
146+
for x in range(10):
147+
var target: List[String] = L_s(str(x), str(x + 1), str(x + 3))
148+
var found = False
149+
for y in range(len(i)):
150+
if j[y] == target:
151+
found = True
152+
break
153+
assert_true(found)
154+
155+
# Non-deteministic tests
156+
# Given the number of permutations of size 1000 is 1000!,
157+
# we rely on the assertion that a truly random shuffle should not
158+
# result in the same order as the to pre-shuffle list with extremely
159+
# high probability.
160+
var l = L_i()
161+
var m = L_i()
162+
for i in range(1000):
163+
l.append(i)
164+
m.append(i)
165+
shuffle(l)
166+
assert_equal(len(l), len(m))
167+
assert_true(l != m)
168+
shuffle(m)
169+
assert_equal(len(l), len(m))
170+
assert_true(l != m)
171+
172+
76173
def main():
77174
test_random()
78175
test_seed()
176+
test_shuffle()

0 commit comments

Comments
 (0)
0